Skip to main content

shopify_sdk/auth/oauth/
state.rs

1//! State parameter handling for OAuth CSRF protection.
2//!
3//! This module provides the [`StateParam`] type for generating and parsing
4//! OAuth state parameters used for CSRF protection during the authorization flow.
5//!
6//! # Overview
7//!
8//! The state parameter serves two purposes in OAuth:
9//! 1. **CSRF Protection**: Prevents cross-site request forgery attacks by ensuring
10//!    the callback was initiated by a legitimate authorization request.
11//! 2. **Data Preservation**: Optionally carries custom data through the OAuth flow.
12//!
13//! # Formats
14//!
15//! `StateParam` supports three usage patterns:
16//!
17//! - **Simple nonce**: Generated via [`StateParam::new()`], a 15-character
18//!   alphanumeric string for basic CSRF protection.
19//! - **Structured with data**: Generated via [`StateParam::with_data()`], embeds
20//!   a nonce and custom JSON data in a base64-encoded string.
21//! - **Raw string**: Created via [`StateParam::from_raw()`], wraps an arbitrary
22//!   string for advanced use cases.
23//!
24//! # Example
25//!
26//! ```rust
27//! use shopify_sdk::auth::oauth::StateParam;
28//! use serde::{Serialize, Deserialize};
29//!
30//! // Simple CSRF protection
31//! let state = StateParam::new();
32//! assert_eq!(state.nonce().len(), 15);
33//!
34//! // Embed custom data through the flow
35//! #[derive(Serialize, Deserialize, Debug, PartialEq)]
36//! struct FlowData {
37//!     return_url: String,
38//! }
39//!
40//! let data = FlowData { return_url: "/dashboard".to_string() };
41//! let state = StateParam::with_data(&data);
42//! let extracted: Option<FlowData> = state.extract_data();
43//! assert_eq!(extracted.unwrap().return_url, "/dashboard");
44//! ```
45
46use base64::prelude::*;
47use rand::distributions::Alphanumeric;
48use rand::Rng;
49use serde::{de::DeserializeOwned, Deserialize, Serialize};
50use std::fmt;
51
/// OAuth `state` parameter providing CSRF protection and optional data preservation.
///
/// Wraps the state string sent with the authorization request and echoed back on
/// the callback. The value is either a random nonce, a base64-encoded JSON payload
/// carrying a nonce plus custom data, or an arbitrary caller-supplied string.
///
/// # Thread Safety
///
/// `StateParam` is `Send + Sync`, making it safe to share across threads.
///
/// # Example
///
/// ```rust
/// use shopify_sdk::auth::oauth::StateParam;
///
/// // Generate a simple state for CSRF protection
/// let state = StateParam::new();
/// println!("State: {}", state);
/// println!("Nonce: {}", state.nonce());
///
/// // Use as_ref for URL encoding
/// let encoded = urlencoding::encode(state.as_ref());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateParam {
    /// The complete state string.
    value: String,
    /// `true` when `value` is a structured state: base64-encoded JSON holding a
    /// nonce and custom data.
    is_structured: bool,
}
81
82/// Internal structure for structured state parameters.
83#[derive(Serialize, Deserialize)]
84struct StructuredState<T> {
85    nonce: String,
86    data: T,
87}
88
89/// Internal structure for extracting just the nonce.
90#[derive(Deserialize)]
91struct NonceOnly {
92    nonce: String,
93}
94
95// Verify StateParam is Send + Sync at compile time
96const _: fn() = || {
97    const fn assert_send_sync<T: Send + Sync>() {}
98    assert_send_sync::<StateParam>();
99};
100
101impl StateParam {
102    /// The length of generated nonces (matches Ruby SDK's SecureRandom.alphanumeric(15)).
103    const NONCE_LENGTH: usize = 15;
104
105    /// Creates a new state parameter with a cryptographically secure random nonce.
106    ///
107    /// The nonce is a 15-character alphanumeric string generated using a
108    /// cryptographically secure random number generator, matching the Ruby SDK's
109    /// `SecureRandom.alphanumeric(15)` behavior.
110    ///
111    /// # Example
112    ///
113    /// ```rust
114    /// use shopify_sdk::auth::oauth::StateParam;
115    ///
116    /// let state = StateParam::new();
117    /// assert_eq!(state.nonce().len(), 15);
118    /// assert!(state.nonce().chars().all(|c| c.is_ascii_alphanumeric()));
119    /// ```
120    #[must_use]
121    pub fn new() -> Self {
122        let nonce: String = rand::thread_rng()
123            .sample_iter(&Alphanumeric)
124            .take(Self::NONCE_LENGTH)
125            .map(char::from)
126            .collect();
127
128        Self {
129            value: nonce,
130            is_structured: false,
131        }
132    }
133
134    /// Creates a state parameter with embedded custom data.
135    ///
136    /// The state is created as a base64-encoded JSON object containing both
137    /// a secure random nonce and the provided data. This allows passing custom
138    /// information through the OAuth flow (e.g., a return URL).
139    ///
140    /// # Arguments
141    ///
142    /// * `data` - Any serializable data to embed in the state
143    ///
144    /// # Example
145    ///
146    /// ```rust
147    /// use shopify_sdk::auth::oauth::StateParam;
148    /// use serde::{Serialize, Deserialize};
149    ///
150    /// #[derive(Serialize, Deserialize)]
151    /// struct FlowContext {
152    ///     return_url: String,
153    ///     user_id: u64,
154    /// }
155    ///
156    /// let context = FlowContext {
157    ///     return_url: "/dashboard".to_string(),
158    ///     user_id: 12345,
159    /// };
160    ///
161    /// let state = StateParam::with_data(&context);
162    /// // State is base64-encoded, can be safely used in URLs
163    /// println!("State for OAuth: {}", state);
164    /// ```
165    #[must_use]
166    pub fn with_data<T: Serialize>(data: &T) -> Self {
167        let nonce: String = rand::thread_rng()
168            .sample_iter(&Alphanumeric)
169            .take(Self::NONCE_LENGTH)
170            .map(char::from)
171            .collect();
172
173        let structured = StructuredState { nonce, data };
174        let json = serde_json::to_string(&structured).unwrap_or_default();
175        let encoded = BASE64_STANDARD.encode(json.as_bytes());
176
177        Self {
178            value: encoded,
179            is_structured: true,
180        }
181    }
182
183    /// Creates a state parameter from a raw string.
184    ///
185    /// This allows advanced users to provide their own state value. The string
186    /// is used as-is without any processing or validation.
187    ///
188    /// # Arguments
189    ///
190    /// * `raw` - The raw state string to use
191    ///
192    /// # Example
193    ///
194    /// ```rust
195    /// use shopify_sdk::auth::oauth::StateParam;
196    ///
197    /// let state = StateParam::from_raw("custom-state-value");
198    /// assert_eq!(state.as_ref(), "custom-state-value");
199    /// ```
200    #[must_use]
201    pub fn from_raw(raw: impl Into<String>) -> Self {
202        Self {
203            value: raw.into(),
204            is_structured: false,
205        }
206    }
207
208    /// Returns the raw state value.
209    ///
210    /// For simple states (created with `new()` or `from_raw()`), this returns
211    /// the nonce or raw value directly. For structured states (created with
212    /// `with_data()`), this returns the full base64-encoded value.
213    ///
214    /// To extract the actual nonce from a structured state, use [`extract_nonce()`].
215    ///
216    /// # Example
217    ///
218    /// ```rust
219    /// use shopify_sdk::auth::oauth::StateParam;
220    ///
221    /// // Simple state - nonce() returns the 15-char nonce
222    /// let simple = StateParam::new();
223    /// assert_eq!(simple.nonce().len(), 15);
224    ///
225    /// // Structured state - nonce() returns the full encoded value
226    /// // Use extract_nonce() to get the actual nonce
227    /// let structured = StateParam::with_data(&"test");
228    /// let actual_nonce = structured.extract_nonce();
229    /// assert_eq!(actual_nonce.len(), 15);
230    /// ```
231    ///
232    /// [`extract_nonce()`]: Self::extract_nonce
233    #[must_use]
234    pub fn nonce(&self) -> &str {
235        &self.value
236    }
237
238    /// Extracts the embedded data from a structured state.
239    ///
240    /// Attempts to base64-decode the state, parse it as JSON, and deserialize
241    /// the `data` field to the specified type.
242    ///
243    /// # Returns
244    ///
245    /// - `Some(T)` if the state was structured and the data could be deserialized
246    /// - `None` if the state is not structured or deserialization fails
247    ///
248    /// # Example
249    ///
250    /// ```rust
251    /// use shopify_sdk::auth::oauth::StateParam;
252    /// use serde::{Serialize, Deserialize};
253    ///
254    /// #[derive(Serialize, Deserialize, Debug, PartialEq)]
255    /// struct UserData {
256    ///     name: String,
257    /// }
258    ///
259    /// let data = UserData { name: "Alice".to_string() };
260    /// let state = StateParam::with_data(&data);
261    ///
262    /// let extracted: Option<UserData> = state.extract_data();
263    /// assert_eq!(extracted, Some(data));
264    ///
265    /// // Simple states don't have embedded data
266    /// let simple = StateParam::new();
267    /// let extracted: Option<UserData> = simple.extract_data();
268    /// assert!(extracted.is_none());
269    /// ```
270    #[must_use]
271    pub fn extract_data<T: DeserializeOwned>(&self) -> Option<T> {
272        // Attempt to decode base64
273        let decoded = BASE64_STANDARD.decode(self.value.as_bytes()).ok()?;
274        let json_str = String::from_utf8(decoded).ok()?;
275
276        // Parse as structured state
277        let structured: StructuredState<T> = serde_json::from_str(&json_str).ok()?;
278        Some(structured.data)
279    }
280
281    /// Extracts the nonce from a potentially structured state.
282    ///
283    /// For simple states, returns the full value (which is the nonce).
284    /// For structured states, decodes and extracts the actual nonce.
285    ///
286    /// # Returns
287    ///
288    /// The 15-character nonce string, or the full value if parsing fails.
289    ///
290    /// # Example
291    ///
292    /// ```rust
293    /// use shopify_sdk::auth::oauth::StateParam;
294    ///
295    /// // Simple state
296    /// let simple = StateParam::new();
297    /// assert_eq!(simple.extract_nonce().len(), 15);
298    ///
299    /// // Structured state
300    /// let structured = StateParam::with_data(&42);
301    /// assert_eq!(structured.extract_nonce().len(), 15);
302    /// ```
303    #[must_use]
304    pub fn extract_nonce(&self) -> String {
305        if !self.is_structured {
306            return self.value.clone();
307        }
308
309        // Try to decode and extract nonce
310        if let Ok(decoded) = BASE64_STANDARD.decode(self.value.as_bytes()) {
311            if let Ok(json_str) = String::from_utf8(decoded) {
312                if let Ok(nonce_only) = serde_json::from_str::<NonceOnly>(&json_str) {
313                    return nonce_only.nonce;
314                }
315            }
316        }
317
318        // Fallback to full value if parsing fails
319        self.value.clone()
320    }
321}
322
323impl Default for StateParam {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329impl fmt::Display for StateParam {
330    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
331        write!(f, "{}", self.value)
332    }
333}
334
335impl AsRef<str> for StateParam {
336    fn as_ref(&self) -> &str {
337        &self.value
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Asserts that `candidate` is a 15-character ASCII-alphanumeric nonce.
    fn assert_alnum_nonce(candidate: &str) {
        assert_eq!(candidate.len(), 15);
        assert!(candidate.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_new_generates_15_char_alphanumeric_nonce() {
        let generated = StateParam::new();
        assert_alnum_nonce(generated.nonce());
    }

    #[test]
    fn test_new_generates_unique_nonces() {
        // Two independent draws from a 62^15 space colliding is astronomically unlikely.
        let (first, second) = (StateParam::new(), StateParam::new());
        assert_ne!(first.nonce(), second.nonce());
    }

    #[test]
    fn test_with_data_embeds_json_in_base64() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct TestData {
            return_url: String,
        }

        let state = StateParam::with_data(&TestData {
            return_url: "/dashboard".to_string(),
        });

        // The value must be valid base64 that wraps UTF-8 JSON.
        let bytes = BASE64_STANDARD
            .decode(state.as_ref().as_bytes())
            .expect("structured state should be valid base64");
        let json = String::from_utf8(bytes).unwrap();
        for needle in ["nonce", "data", "/dashboard"] {
            assert!(json.contains(needle), "missing {needle:?} in {json}");
        }
    }

    #[test]
    fn test_from_raw_wraps_string_correctly() {
        let raw = "custom-state-123";
        let wrapped = StateParam::from_raw(raw);
        assert_eq!(wrapped.as_ref(), raw);
        assert_eq!(wrapped.nonce(), raw);
    }

    #[test]
    fn test_nonce_returns_value_for_simple_state() {
        assert_eq!(StateParam::new().nonce().len(), 15);
    }

    #[test]
    fn test_nonce_returns_full_value_for_structured_state() {
        let structured = StateParam::with_data(&"test");
        // nonce() exposes the whole base64 payload; extract_nonce() digs out the real nonce.
        assert!(structured.nonce().len() > 15);
        assert_eq!(structured.extract_nonce().len(), 15);
    }

    #[test]
    fn test_extract_data_returns_embedded_data() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct FlowData {
            user_id: u64,
            redirect_to: String,
        }

        let expected = FlowData {
            user_id: 12345,
            redirect_to: "/admin/orders".to_string(),
        };
        let structured = StateParam::with_data(&expected);
        assert_eq!(structured.extract_data::<FlowData>(), Some(expected));
    }

    #[test]
    fn test_extract_data_returns_none_for_simple_state() {
        #[derive(Deserialize)]
        struct SomeData {
            #[allow(dead_code)]
            field: String,
        }

        assert!(StateParam::new().extract_data::<SomeData>().is_none());
    }

    #[test]
    fn test_extract_data_returns_none_for_type_mismatch() {
        #[derive(Serialize)]
        struct DataA {
            field_a: String,
        }

        #[derive(Deserialize)]
        struct DataB {
            #[allow(dead_code)]
            field_b: i32,
        }

        let structured = StateParam::with_data(&DataA {
            field_a: "test".to_string(),
        });
        assert!(structured.extract_data::<DataB>().is_none());
    }

    #[test]
    fn test_display_returns_full_state_string() {
        assert_eq!(StateParam::from_raw("display-test").to_string(), "display-test");

        let generated = StateParam::new();
        assert_eq!(generated.to_string(), generated.as_ref());
    }

    #[test]
    fn test_as_ref_provides_string_slice() {
        let wrapped = StateParam::from_raw("ref-test");
        let slice: &str = wrapped.as_ref();
        assert_eq!(slice, "ref-test");
    }

    #[test]
    fn test_with_data_handles_various_types() {
        let text = StateParam::with_data(&"simple string");
        assert_eq!(text.extract_data::<String>(), Some("simple string".to_string()));

        let number = StateParam::with_data(&42i32);
        assert_eq!(number.extract_data::<i32>(), Some(42));

        let list = StateParam::with_data(&vec![1, 2, 3]);
        assert_eq!(list.extract_data::<Vec<i32>>(), Some(vec![1, 2, 3]));
    }

    #[test]
    fn test_extract_nonce_from_structured_state() {
        #[derive(Serialize)]
        struct Data {
            value: i32,
        }

        let structured = StateParam::with_data(&Data { value: 42 });
        assert_alnum_nonce(&structured.extract_nonce());
    }

    #[test]
    fn test_state_param_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<StateParam>();
    }

    #[test]
    fn test_state_param_clone() {
        let original = StateParam::new();
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_state_param_eq() {
        let same_a = StateParam::from_raw("same");
        let same_b = StateParam::from_raw("same");
        let other = StateParam::from_raw("different");

        assert_eq!(same_a, same_b);
        assert_ne!(same_a, other);
    }

    #[test]
    fn test_state_param_default() {
        assert_eq!(StateParam::default().nonce().len(), 15);
    }
}