rustauth_core/auth/oauth/
state.rs1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use sha2::{Digest, Sha256};
4use time::{Duration, OffsetDateTime};
5
6use crate::context::AuthContext;
7use crate::crypto::random::generate_random_string;
8use crate::db::DbAdapter;
9use crate::error::RustAuthError;
10use crate::options::OAuthStateStoreStrategy;
11use crate::verification::{CreateVerificationInput, DbVerificationStore};
12
13use super::tokens::{decrypt_with_context, encrypt_with_context};
14
15fn verification_store<'a>(
16 context: &'a AuthContext,
17 adapter: &'a dyn DbAdapter,
18) -> DbVerificationStore<'a> {
19 DbVerificationStore::with_options(
20 adapter,
21 context.db_schema.clone(),
22 context.options.verification.clone(),
23 )
24}
25
26#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
27pub struct OAuthStateLink {
28 pub email: String,
29 pub user_id: String,
30}
31
32#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
33pub struct OAuthStateData {
34 pub callback_url: String,
35 pub code_verifier: String,
36 pub oauth_state: String,
37 pub error_url: Option<String>,
38 pub new_user_url: Option<String>,
39 pub link: Option<OAuthStateLink>,
40 pub expires_at: OffsetDateTime,
41 pub request_sign_up: bool,
42 pub additional_data: Value,
43}
44
45#[derive(Debug, Clone, PartialEq)]
46pub struct OAuthStateInput {
47 pub callback_url: String,
48 pub error_url: Option<String>,
49 pub new_user_url: Option<String>,
50 pub link: Option<OAuthStateLink>,
51 pub request_sign_up: bool,
52 pub additional_data: Value,
53 pub expires_at: Option<OffsetDateTime>,
54}
55
56impl Default for OAuthStateInput {
57 fn default() -> Self {
58 Self {
59 callback_url: String::new(),
60 error_url: None,
61 new_user_url: None,
62 link: None,
63 request_sign_up: false,
64 additional_data: Value::Null,
65 expires_at: None,
66 }
67 }
68}
69
70#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
71pub struct GeneratedOAuthState {
72 pub state: String,
73 pub data: OAuthStateData,
74}
75
/// Arguments for [`parse_oauth_state_with_input`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct OAuthStateParseInput<'a> {
    /// State string previously returned by [`generate_oauth_state`].
    pub state: &'a str,
    /// Expected [`OAuthStateData::oauth_state`] value (presumably read from the
    /// state cookie — confirm with callers). `None` never matches.
    pub oauth_state: Option<&'a str>,
    /// When `true`, the `oauth_state` comparison is not performed at all.
    pub skip_state_cookie_check: bool,
}
82
83pub async fn generate_oauth_state(
84 context: &AuthContext,
85 adapter: Option<&dyn DbAdapter>,
86 input: OAuthStateInput,
87) -> Result<GeneratedOAuthState, RustAuthError> {
88 if input.callback_url.is_empty() {
89 return Err(RustAuthError::Api("callback URL is required".to_owned()));
90 }
91 let data = OAuthStateData {
92 callback_url: input.callback_url,
93 code_verifier: generate_random_string(128),
94 oauth_state: generate_random_string(32),
95 error_url: input.error_url,
96 new_user_url: input.new_user_url,
97 link: input.link,
98 expires_at: input
99 .expires_at
100 .unwrap_or_else(|| OffsetDateTime::now_utc() + Duration::minutes(10)),
101 request_sign_up: input.request_sign_up,
102 additional_data: input.additional_data,
103 };
104 let state = match context.options.account.store_state_strategy {
105 OAuthStateStoreStrategy::Cookie => {
106 let json = serde_json::to_string(&data)
107 .map_err(|error| RustAuthError::Crypto(error.to_string()))?;
108 let encrypted = encrypt_with_context(&json, context)?;
109 if let Some(adapter) = adapter {
115 verification_store(context, adapter)
116 .create_verification(CreateVerificationInput::new(
117 cookie_state_single_use_identifier(&encrypted),
118 String::new(),
119 data.expires_at,
120 ))
121 .await?;
122 }
123 encrypted
124 }
125 OAuthStateStoreStrategy::Database => {
126 let adapter = adapter.ok_or_else(|| {
127 RustAuthError::Adapter("database OAuth state requires an adapter".to_owned())
128 })?;
129 let state = generate_random_string(32);
130 let json = serde_json::to_string(&data)
131 .map_err(|error| RustAuthError::Crypto(error.to_string()))?;
132 verification_store(context, adapter)
133 .create_verification(CreateVerificationInput::new(
134 oauth_state_identifier(&state),
135 json,
136 data.expires_at,
137 ))
138 .await?;
139 state
140 }
141 };
142 Ok(GeneratedOAuthState { state, data })
143}
144
145pub async fn parse_oauth_state(
146 context: &AuthContext,
147 adapter: Option<&dyn DbAdapter>,
148 state: &str,
149) -> Result<OAuthStateData, RustAuthError> {
150 parse_oauth_state_with_input(
151 context,
152 adapter,
153 OAuthStateParseInput {
154 state,
155 oauth_state: None,
156 skip_state_cookie_check: true,
157 },
158 )
159 .await
160}
161
162pub async fn parse_oauth_state_with_input(
163 context: &AuthContext,
164 adapter: Option<&dyn DbAdapter>,
165 input: OAuthStateParseInput<'_>,
166) -> Result<OAuthStateData, RustAuthError> {
167 let state = input.state;
168 let data = match context.options.account.store_state_strategy {
169 OAuthStateStoreStrategy::Cookie => {
170 if let Some(adapter) = adapter {
176 let verifications = verification_store(context, adapter);
177 let identifier = cookie_state_single_use_identifier(state);
178 if verifications
179 .consume_verification_including_expired(&identifier)
180 .await?
181 .is_none()
182 {
183 return Err(RustAuthError::Api("invalid OAuth state".to_owned()));
184 }
185 }
186 let json = decrypt_with_context(state, context)?;
187 serde_json::from_str::<OAuthStateData>(&json)
188 .map_err(|error| RustAuthError::Crypto(error.to_string()))?
189 }
190 OAuthStateStoreStrategy::Database => {
191 let adapter = adapter.ok_or_else(|| {
192 RustAuthError::Adapter("database OAuth state requires an adapter".to_owned())
193 })?;
194 let verifications = verification_store(context, adapter);
195 let identifier = oauth_state_identifier(state);
196 let verification = verifications
197 .consume_verification_including_expired(&identifier)
198 .await?
199 .ok_or_else(|| RustAuthError::Api("invalid OAuth state".to_owned()))?;
200 serde_json::from_str::<OAuthStateData>(&verification.value)
201 .map_err(|error| RustAuthError::Crypto(error.to_string()))?
202 }
203 };
204 if data.expires_at <= OffsetDateTime::now_utc() {
205 return Err(RustAuthError::Api("OAuth state expired".to_owned()));
206 }
207 if !input.skip_state_cookie_check && input.oauth_state != Some(data.oauth_state.as_str()) {
208 return Err(RustAuthError::Api("invalid OAuth state".to_owned()));
209 }
210 Ok(data)
211}
212
/// Builds the verification-store identifier for a database-backed OAuth state
/// handle: the literal prefix `"oauth-state-"` followed by `state`.
pub fn oauth_state_identifier(state: &str) -> String {
    let mut identifier = String::from("oauth-state-");
    identifier.push_str(state);
    identifier
}
216
217fn cookie_state_single_use_identifier(encrypted_state: &str) -> String {
224 let digest = Sha256::digest(encrypted_state.as_bytes());
225 format!("oauth-state-cookie-{}", hex::encode(digest))
226}