shopify_sdk/auth/oauth/state.rs
1//! State parameter handling for OAuth CSRF protection.
2//!
3//! This module provides the [`StateParam`] type for generating and parsing
4//! OAuth state parameters used for CSRF protection during the authorization flow.
5//!
6//! # Overview
7//!
8//! The state parameter serves two purposes in OAuth:
9//! 1. **CSRF Protection**: Prevents cross-site request forgery attacks by ensuring
10//! the callback was initiated by a legitimate authorization request.
11//! 2. **Data Preservation**: Optionally carries custom data through the OAuth flow.
12//!
13//! # Formats
14//!
15//! `StateParam` supports three usage patterns:
16//!
17//! - **Simple nonce**: Generated via [`StateParam::new()`], a 15-character
18//! alphanumeric string for basic CSRF protection.
19//! - **Structured with data**: Generated via [`StateParam::with_data()`], embeds
20//! a nonce and custom JSON data in a base64-encoded string.
21//! - **Raw string**: Created via [`StateParam::from_raw()`], wraps an arbitrary
22//! string for advanced use cases.
23//!
24//! # Example
25//!
26//! ```rust
27//! use shopify_sdk::auth::oauth::StateParam;
28//! use serde::{Serialize, Deserialize};
29//!
30//! // Simple CSRF protection
31//! let state = StateParam::new();
32//! assert_eq!(state.nonce().len(), 15);
33//!
34//! // Embed custom data through the flow
35//! #[derive(Serialize, Deserialize, Debug, PartialEq)]
36//! struct FlowData {
37//! return_url: String,
38//! }
39//!
40//! let data = FlowData { return_url: "/dashboard".to_string() };
41//! let state = StateParam::with_data(&data);
42//! let extracted: Option<FlowData> = state.extract_data();
43//! assert_eq!(extracted.unwrap().return_url, "/dashboard");
44//! ```
45
46use base64::prelude::*;
47use rand::distributions::Alphanumeric;
48use rand::Rng;
49use serde::{de::DeserializeOwned, Deserialize, Serialize};
50use std::fmt;
51
/// OAuth `state` parameter used for CSRF protection and for carrying data
/// through the authorization round trip.
///
/// A `StateParam` wraps the exact string that is sent to, and echoed back by,
/// the authorization server. Nonces come from `rand::thread_rng()` (a
/// cryptographically secure RNG). Arbitrary serializable data can optionally be
/// embedded next to the nonce.
///
/// # Thread Safety
///
/// `StateParam` owns only a `String` and a `bool`, so it is `Send + Sync` and
/// can be shared across threads.
///
/// # Equality
///
/// The derived `PartialEq` compares both the string value and the internal
/// "structured" flag. A state built with [`StateParam::with_data()`] is
/// therefore *not* equal to [`StateParam::from_raw()`] applied to the same
/// string. When checking a callback, compare the `as_ref()` strings (or the
/// extracted nonces) rather than the `StateParam` values themselves.
///
/// # Example
///
/// ```rust
/// use shopify_sdk::auth::oauth::StateParam;
///
/// // Generate a simple state for CSRF protection
/// let state = StateParam::new();
/// println!("State: {}", state);
/// println!("Nonce: {}", state.nonce());
///
/// // `as_ref` exposes the raw string, e.g. for percent-encoding into a URL
/// let raw: &str = state.as_ref();
/// assert_eq!(raw, state.nonce());
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StateParam {
    /// Exact state string that is sent to and returned by the authorization server.
    value: String,
    /// `true` when `value` is base64-encoded JSON holding a nonce plus embedded data.
    is_structured: bool,
}
81
82/// Internal structure for structured state parameters.
83#[derive(Serialize, Deserialize)]
84struct StructuredState<T> {
85 nonce: String,
86 data: T,
87}
88
89/// Internal structure for extracting just the nonce.
90#[derive(Deserialize)]
91struct NonceOnly {
92 nonce: String,
93}
94
95// Verify StateParam is Send + Sync at compile time
96const _: fn() = || {
97 const fn assert_send_sync<T: Send + Sync>() {}
98 assert_send_sync::<StateParam>();
99};
100
101impl StateParam {
102 /// The length of generated nonces (matches Ruby SDK's SecureRandom.alphanumeric(15)).
103 const NONCE_LENGTH: usize = 15;
104
105 /// Creates a new state parameter with a cryptographically secure random nonce.
106 ///
107 /// The nonce is a 15-character alphanumeric string generated using a
108 /// cryptographically secure random number generator, matching the Ruby SDK's
109 /// `SecureRandom.alphanumeric(15)` behavior.
110 ///
111 /// # Example
112 ///
113 /// ```rust
114 /// use shopify_sdk::auth::oauth::StateParam;
115 ///
116 /// let state = StateParam::new();
117 /// assert_eq!(state.nonce().len(), 15);
118 /// assert!(state.nonce().chars().all(|c| c.is_ascii_alphanumeric()));
119 /// ```
120 #[must_use]
121 pub fn new() -> Self {
122 let nonce: String = rand::thread_rng()
123 .sample_iter(&Alphanumeric)
124 .take(Self::NONCE_LENGTH)
125 .map(char::from)
126 .collect();
127
128 Self {
129 value: nonce,
130 is_structured: false,
131 }
132 }
133
134 /// Creates a state parameter with embedded custom data.
135 ///
136 /// The state is created as a base64-encoded JSON object containing both
137 /// a secure random nonce and the provided data. This allows passing custom
138 /// information through the OAuth flow (e.g., a return URL).
139 ///
140 /// # Arguments
141 ///
142 /// * `data` - Any serializable data to embed in the state
143 ///
144 /// # Example
145 ///
146 /// ```rust
147 /// use shopify_sdk::auth::oauth::StateParam;
148 /// use serde::{Serialize, Deserialize};
149 ///
150 /// #[derive(Serialize, Deserialize)]
151 /// struct FlowContext {
152 /// return_url: String,
153 /// user_id: u64,
154 /// }
155 ///
156 /// let context = FlowContext {
157 /// return_url: "/dashboard".to_string(),
158 /// user_id: 12345,
159 /// };
160 ///
161 /// let state = StateParam::with_data(&context);
162 /// // State is base64-encoded, can be safely used in URLs
163 /// println!("State for OAuth: {}", state);
164 /// ```
165 #[must_use]
166 pub fn with_data<T: Serialize>(data: &T) -> Self {
167 let nonce: String = rand::thread_rng()
168 .sample_iter(&Alphanumeric)
169 .take(Self::NONCE_LENGTH)
170 .map(char::from)
171 .collect();
172
173 let structured = StructuredState { nonce, data };
174 let json = serde_json::to_string(&structured).unwrap_or_default();
175 let encoded = BASE64_STANDARD.encode(json.as_bytes());
176
177 Self {
178 value: encoded,
179 is_structured: true,
180 }
181 }
182
183 /// Creates a state parameter from a raw string.
184 ///
185 /// This allows advanced users to provide their own state value. The string
186 /// is used as-is without any processing or validation.
187 ///
188 /// # Arguments
189 ///
190 /// * `raw` - The raw state string to use
191 ///
192 /// # Example
193 ///
194 /// ```rust
195 /// use shopify_sdk::auth::oauth::StateParam;
196 ///
197 /// let state = StateParam::from_raw("custom-state-value");
198 /// assert_eq!(state.as_ref(), "custom-state-value");
199 /// ```
200 #[must_use]
201 pub fn from_raw(raw: impl Into<String>) -> Self {
202 Self {
203 value: raw.into(),
204 is_structured: false,
205 }
206 }
207
208 /// Returns the raw state value.
209 ///
210 /// For simple states (created with `new()` or `from_raw()`), this returns
211 /// the nonce or raw value directly. For structured states (created with
212 /// `with_data()`), this returns the full base64-encoded value.
213 ///
214 /// To extract the actual nonce from a structured state, use [`extract_nonce()`].
215 ///
216 /// # Example
217 ///
218 /// ```rust
219 /// use shopify_sdk::auth::oauth::StateParam;
220 ///
221 /// // Simple state - nonce() returns the 15-char nonce
222 /// let simple = StateParam::new();
223 /// assert_eq!(simple.nonce().len(), 15);
224 ///
225 /// // Structured state - nonce() returns the full encoded value
226 /// // Use extract_nonce() to get the actual nonce
227 /// let structured = StateParam::with_data(&"test");
228 /// let actual_nonce = structured.extract_nonce();
229 /// assert_eq!(actual_nonce.len(), 15);
230 /// ```
231 ///
232 /// [`extract_nonce()`]: Self::extract_nonce
233 #[must_use]
234 pub fn nonce(&self) -> &str {
235 &self.value
236 }
237
238 /// Extracts the embedded data from a structured state.
239 ///
240 /// Attempts to base64-decode the state, parse it as JSON, and deserialize
241 /// the `data` field to the specified type.
242 ///
243 /// # Returns
244 ///
245 /// - `Some(T)` if the state was structured and the data could be deserialized
246 /// - `None` if the state is not structured or deserialization fails
247 ///
248 /// # Example
249 ///
250 /// ```rust
251 /// use shopify_sdk::auth::oauth::StateParam;
252 /// use serde::{Serialize, Deserialize};
253 ///
254 /// #[derive(Serialize, Deserialize, Debug, PartialEq)]
255 /// struct UserData {
256 /// name: String,
257 /// }
258 ///
259 /// let data = UserData { name: "Alice".to_string() };
260 /// let state = StateParam::with_data(&data);
261 ///
262 /// let extracted: Option<UserData> = state.extract_data();
263 /// assert_eq!(extracted, Some(data));
264 ///
265 /// // Simple states don't have embedded data
266 /// let simple = StateParam::new();
267 /// let extracted: Option<UserData> = simple.extract_data();
268 /// assert!(extracted.is_none());
269 /// ```
270 #[must_use]
271 pub fn extract_data<T: DeserializeOwned>(&self) -> Option<T> {
272 // Attempt to decode base64
273 let decoded = BASE64_STANDARD.decode(self.value.as_bytes()).ok()?;
274 let json_str = String::from_utf8(decoded).ok()?;
275
276 // Parse as structured state
277 let structured: StructuredState<T> = serde_json::from_str(&json_str).ok()?;
278 Some(structured.data)
279 }
280
281 /// Extracts the nonce from a potentially structured state.
282 ///
283 /// For simple states, returns the full value (which is the nonce).
284 /// For structured states, decodes and extracts the actual nonce.
285 ///
286 /// # Returns
287 ///
288 /// The 15-character nonce string, or the full value if parsing fails.
289 ///
290 /// # Example
291 ///
292 /// ```rust
293 /// use shopify_sdk::auth::oauth::StateParam;
294 ///
295 /// // Simple state
296 /// let simple = StateParam::new();
297 /// assert_eq!(simple.extract_nonce().len(), 15);
298 ///
299 /// // Structured state
300 /// let structured = StateParam::with_data(&42);
301 /// assert_eq!(structured.extract_nonce().len(), 15);
302 /// ```
303 #[must_use]
304 pub fn extract_nonce(&self) -> String {
305 if !self.is_structured {
306 return self.value.clone();
307 }
308
309 // Try to decode and extract nonce
310 if let Ok(decoded) = BASE64_STANDARD.decode(self.value.as_bytes()) {
311 if let Ok(json_str) = String::from_utf8(decoded) {
312 if let Ok(nonce_only) = serde_json::from_str::<NonceOnly>(&json_str) {
313 return nonce_only.nonce;
314 }
315 }
316 }
317
318 // Fallback to full value if parsing fails
319 self.value.clone()
320 }
321}
322
323impl Default for StateParam {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329impl fmt::Display for StateParam {
330 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331 write!(f, "{}", self.value)
332 }
333}
334
335impl AsRef<str> for StateParam {
336 fn as_ref(&self) -> &str {
337 &self.value
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    #[test]
    fn test_new_generates_15_char_alphanumeric_nonce() {
        let state = StateParam::new();
        let nonce = state.nonce();

        assert_eq!(nonce.len(), 15);
        assert!(nonce.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_new_generates_unique_nonces() {
        let state1 = StateParam::new();
        let state2 = StateParam::new();

        // 62^15 possible nonces: a collision here indicates a broken RNG, not bad luck
        assert_ne!(state1.nonce(), state2.nonce());
    }

    #[test]
    fn test_with_data_embeds_json_in_base64() {
        #[derive(Serialize)]
        struct TestData {
            return_url: String,
        }

        let data = TestData {
            return_url: "/dashboard".to_string(),
        };
        let state = StateParam::with_data(&data);

        // Decode and parse the payload instead of substring-matching, so the
        // test pins the actual `{"nonce": ..., "data": ...}` layout.
        let decoded = BASE64_STANDARD
            .decode(state.as_ref().as_bytes())
            .expect("structured state should be valid base64");
        let json: serde_json::Value =
            serde_json::from_slice(&decoded).expect("decoded state should be valid JSON");

        let nonce = json["nonce"]
            .as_str()
            .expect("`nonce` should be a JSON string");
        assert_eq!(nonce.len(), 15);
        assert!(nonce.chars().all(|c| c.is_ascii_alphanumeric()));
        assert_eq!(nonce, state.extract_nonce());
        assert_eq!(json["data"]["return_url"], "/dashboard");
    }

    #[test]
    fn test_from_raw_wraps_string_correctly() {
        let state = StateParam::from_raw("custom-state-123");
        assert_eq!(state.as_ref(), "custom-state-123");
        assert_eq!(state.nonce(), "custom-state-123");
    }

    #[test]
    fn test_nonce_returns_value_for_simple_state() {
        let state = StateParam::new();
        assert_eq!(state.nonce().len(), 15);
    }

    #[test]
    fn test_nonce_returns_full_value_for_structured_state() {
        let state = StateParam::with_data(&"test");
        // For structured state, nonce() returns the full base64 value
        // Use extract_nonce() to get the actual nonce
        assert!(state.nonce().len() > 15);
        assert_eq!(state.extract_nonce().len(), 15);
    }

    #[test]
    fn test_extract_data_returns_embedded_data() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct FlowData {
            user_id: u64,
            redirect_to: String,
        }

        let original = FlowData {
            user_id: 12345,
            redirect_to: "/admin/orders".to_string(),
        };
        let state = StateParam::with_data(&original);

        let extracted: Option<FlowData> = state.extract_data();
        assert_eq!(extracted, Some(original));
    }

    #[test]
    fn test_extract_data_from_raw_round_trip() {
        // The callback side only has the raw string; data must still be recoverable.
        let sent = StateParam::with_data(&"return-to-orders");
        let received = StateParam::from_raw(sent.to_string());

        let extracted: Option<String> = received.extract_data();
        assert_eq!(extracted, Some("return-to-orders".to_string()));
    }

    #[test]
    fn test_extract_data_returns_none_for_simple_state() {
        #[derive(Deserialize)]
        struct SomeData {
            #[allow(dead_code)] // only the deserialization outcome is under test
            field: String,
        }

        let state = StateParam::new();
        let extracted: Option<SomeData> = state.extract_data();
        assert!(extracted.is_none());
    }

    #[test]
    fn test_extract_data_returns_none_for_non_base64_raw_state() {
        let state = StateParam::from_raw("not base64!");
        let extracted: Option<String> = state.extract_data();
        assert!(extracted.is_none());
    }

    #[test]
    fn test_extract_data_returns_none_for_type_mismatch() {
        #[derive(Serialize)]
        struct DataA {
            field_a: String,
        }

        #[derive(Deserialize)]
        struct DataB {
            #[allow(dead_code)] // only the deserialization outcome is under test
            field_b: i32,
        }

        let data = DataA {
            field_a: "test".to_string(),
        };
        let state = StateParam::with_data(&data);

        let extracted: Option<DataB> = state.extract_data();
        assert!(extracted.is_none());
    }

    #[test]
    fn test_display_returns_full_state_string() {
        let state = StateParam::from_raw("display-test");
        assert_eq!(format!("{}", state), "display-test");

        let state = StateParam::new();
        assert_eq!(format!("{}", state), state.as_ref());
    }

    #[test]
    fn test_as_ref_provides_string_slice() {
        let state = StateParam::from_raw("ref-test");
        let s: &str = state.as_ref();
        assert_eq!(s, "ref-test");
    }

    #[test]
    fn test_with_data_handles_various_types() {
        // String
        let state = StateParam::with_data(&"simple string");
        let extracted: Option<String> = state.extract_data();
        assert_eq!(extracted, Some("simple string".to_string()));

        // Number
        let state = StateParam::with_data(&42i32);
        let extracted: Option<i32> = state.extract_data();
        assert_eq!(extracted, Some(42));

        // Vec
        let state = StateParam::with_data(&vec![1, 2, 3]);
        let extracted: Option<Vec<i32>> = state.extract_data();
        assert_eq!(extracted, Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_extract_nonce_from_structured_state() {
        #[derive(Serialize)]
        struct Data {
            value: i32,
        }

        let state = StateParam::with_data(&Data { value: 42 });
        let nonce = state.extract_nonce();

        // Nonce should be 15 alphanumeric characters
        assert_eq!(nonce.len(), 15);
        assert!(nonce.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_extract_nonce_returns_value_for_raw_state() {
        let state = StateParam::from_raw("custom-state-123");
        assert_eq!(state.extract_nonce(), "custom-state-123");
    }

    #[test]
    fn test_state_param_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<StateParam>();
    }

    #[test]
    fn test_state_param_clone() {
        let state = StateParam::new();
        let cloned = state.clone();
        assert_eq!(state, cloned);
    }

    #[test]
    fn test_state_param_eq() {
        let state1 = StateParam::from_raw("same");
        let state2 = StateParam::from_raw("same");
        let state3 = StateParam::from_raw("different");

        assert_eq!(state1, state2);
        assert_ne!(state1, state3);
    }

    #[test]
    fn test_state_param_default() {
        let state = StateParam::default();
        assert_eq!(state.nonce().len(), 15);
    }
}