Skip to main content

sqlx_data_params/
cursor.rs

1use crate::FilterValue;
2use crate::{IntoParams, Params};
3
4#[cfg(feature = "json")]
5use serde::{Deserialize, Serialize};
6
7// ================================================================================================
8// CORE TYPES
9// ================================================================================================
10
11#[derive(Clone, Debug, PartialEq)]
12#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
13pub struct CursorEntry {
14    pub value: CursorValue,
15}
16
17/// Client-facing cursor data - contains only the serializable data that goes to the client
18#[derive(Clone, Debug, PartialEq)]
19#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
20#[derive(Default)]
21pub struct Cursor {
22    #[cfg_attr(feature = "json", serde(skip_serializing_if = "Vec::is_empty"))]
23    pub entries: Vec<CursorEntry>,
24    #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
25    pub version: Option<u8>,
26    #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
27    pub fingerprint: Option<u64>,
28}
29
30/// Internal cursor params with metadata - contains the cursor data plus internal processing metadata
31#[derive(Clone, Debug, PartialEq)]
32#[derive(Default)]
33pub struct CursorParams {
34    /// Internal direction metadata
35    pub direction: Option<CursorDirection>,
36    /// After, Before, and decoded cursor data - used when building queries
37    pub values: Vec<FilterValue>,
38    /// Optional error message if cursor processing failed
39    pub error: Option<String>,
40}
41
/// Which side of the cursor position a page is requested from.
#[derive(Clone, Debug, PartialEq)]
pub enum CursorDirection {
    /// Items that come after the cursor position.
    After,
    /// Items that come before the cursor position.
    Before,
}
47
/// A single serializable cursor value.
///
/// With the `json` feature the enum is (de)serialized untagged. Serde then tries the
/// variants in declaration order (`Int`, `UInt`, `Float`, `Bool`, `String`), so do not
/// reorder them. NOTE(review): a consequence is that a `UInt` small enough for `i64`
/// presumably comes back as `Int` after a JSON round trip — confirm no caller relies on the
/// exact variant.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "json", serde(untagged))]
pub enum CursorValue {
    /// Signed integer value.
    Int(i64),
    /// Unsigned integer value.
    UInt(u64),
    /// Floating-point value.
    Float(f64),
    /// Boolean value.
    Bool(bool),
    /// Text value.
    String(String),
}

impl From<i64> for CursorValue {
    fn from(n: i64) -> Self {
        Self::Int(n)
    }
}

impl From<u64> for CursorValue {
    fn from(n: u64) -> Self {
        Self::UInt(n)
    }
}

impl From<f64> for CursorValue {
    fn from(x: f64) -> Self {
        Self::Float(x)
    }
}

impl From<bool> for CursorValue {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

impl From<String> for CursorValue {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}

impl From<&str> for CursorValue {
    fn from(text: &str) -> Self {
        Self::String(text.to_owned())
    }
}

impl From<i32> for CursorValue {
    /// Widens losslessly to `i64`.
    fn from(n: i32) -> Self {
        Self::Int(i64::from(n))
    }
}

impl From<u32> for CursorValue {
    /// Widens losslessly to `u64`.
    fn from(n: u32) -> Self {
        Self::UInt(u64::from(n))
    }
}

impl From<f32> for CursorValue {
    /// Widens losslessly to `f64`.
    fn from(x: f32) -> Self {
        Self::Float(f64::from(x))
    }
}
112
113pub type Result<T, E = CursorError> = ::std::result::Result<T, E>;
114
115// ================================================================================================
116// ERROR HANDLING
117// ================================================================================================
118
/// Alias for the database integration error type (only with a database feature enabled).
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
pub type SqlxError = sqlx_data_integration::Error;
121
122#[derive(Debug, thiserror::Error)]
123#[non_exhaustive]
124enum CursorErrorKind {
125    #[error(transparent)]
126    #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
127    Sqlx(#[from] SqlxError),
128
129    #[error("Field '{0}' not allowed for cursor pagination")]
130    InvalidField(String),
131
132    #[error("Data is empty")]
133    EmptyData,
134
135    #[error("Encoding cursor failed: {0}")]
136    EncodeError(String),
137
138    #[error("Decoding cursor failed: {0}")]
139    DecodeError(String),
140}
141
142#[derive(Debug)]
143pub struct CursorError(CursorErrorKind);
144
145impl CursorError {
146    /// Create an InvalidField error with automatic type conversion based on features
147    pub fn invalid_field(field: impl Into<String>) -> Self {
148        Self(CursorErrorKind::InvalidField(field.into()))
149    }
150
151    pub fn empty_data() -> Self {
152        Self(CursorErrorKind::EmptyData)
153    }
154
155    pub fn encode_error(msg: impl Into<String>) -> Self {
156        Self(CursorErrorKind::EncodeError(msg.into()))
157    }
158
159    pub fn decode_error(msg: impl Into<String>) -> Self {
160        Self(CursorErrorKind::DecodeError(msg.into()))
161    }
162}
163
// Wrap a `sqlx_data_integration::Error` in a `CursorError` when database features are
// enabled. This lets `?` propagate database errors out of the cursor helpers, which return
// `Result<_, CursorError>`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
impl From<SqlxError> for CursorError {
    fn from(e: SqlxError) -> Self {
        Self(CursorErrorKind::Sqlx(e))
    }
}
171
172impl std::fmt::Display for CursorError {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        self.0.fmt(f)
175    }
176}
177
178impl std::error::Error for CursorError {
179    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
180        self.0.source()
181    }
182}
183
/// Converts a [`CursorError`] into the database integration error.
///
/// Database errors are returned unchanged. Every other cursor error goes into
/// `sqlx_data_integration::Error::Decode`, still wrapped in the public [`CursorError`].
/// Callers outside this module cannot name the private error kind, so wrapping lets them
/// recover the error (e.g. via `downcast_ref::<CursorError>()`). This assumes `Decode` holds
/// a boxed `dyn Error`, as the previous `.into()` conversion implied. `Display` output is
/// unchanged because `CursorError` delegates to the kind.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
impl From<CursorError> for sqlx_data_integration::Error {
    fn from(err: CursorError) -> Self {
        match err.0 {
            // Already the target type: pass through untouched.
            CursorErrorKind::Sqlx(e) => e,
            // NOTE(review): encoding failures also land in `Decode`; no better-fitting
            // variant is visible from this file.
            other => sqlx_data_integration::Error::Decode(CursorError(other).into()),
        }
    }
}
193
194// ================================================================================================
195// IMPLEMENTATIONS
196// ================================================================================================
197
198
199
200impl Cursor {
201    pub fn new() -> Self {
202        Self::default()
203    }
204
205    pub fn new_multi(entries: Vec<CursorEntry>) -> Self {
206        Self { entries, version: None, fingerprint: None }
207    }
208
209    pub fn and_field(mut self, value: impl Into<CursorValue>) -> Self {
210        self.entries.push(CursorEntry {
211            value: value.into(),
212        });
213        self
214    }
215}
216
217impl CursorParams {
218    pub fn new(value: FilterValue, direction: CursorDirection) -> Self {
219        Self {
220            values: vec![value],
221            direction: Some(direction),
222            error: None,
223        }
224    }
225
226    pub fn from_values(values: Vec<FilterValue>, direction: CursorDirection) -> Self {
227        Self {
228            values,
229            direction: Some(direction),
230            error: None,
231        }
232    }
233
234    pub fn with_error(direction: CursorDirection, error: impl Into<String>) -> Self {
235        Self {
236            values: vec![],
237            direction: Some(direction),
238            error: Some(error.into()),
239        }
240    }
241
242    pub fn and_field(mut self, value: FilterValue) -> Self {
243        self.values.push(value);
244        self
245    }
246
247    /// Access to the cursor values
248    pub fn values(&self) -> &[FilterValue] {
249        &self.values
250    }
251
252    /// Check if cursor has values
253    pub fn is_empty(&self) -> bool {
254        self.values.is_empty()
255    }
256
257    /// Get the number of values
258    pub fn len(&self) -> usize {
259        self.values.len()
260    }
261
262    /// Check if this cursor has an error
263    pub fn has_error(&self) -> bool {
264        self.error.is_some()
265    }
266
267    /// Get the error message if any
268    pub fn error(&self) -> Option<&str> {
269        self.error.as_deref()
270    }
271
272    /// Generate cursor from a specific item in the data
273    fn generate_cursor<T: CursorSecureExtract>(
274        data: &[T],
275        has_more: bool,
276        sorting_params: &crate::sort::SortingParams,
277        get_item: impl FnOnce(&[T]) -> Option<&T>,
278    ) -> Result<Option<Cursor>> {
279        if !has_more || data.is_empty() {
280            return Ok(None);
281        }
282
283        // Extract field names from sorting parameters
284        let fields: Vec<String> = sorting_params
285            .sorts()
286            .iter()
287            .map(|s| s.field.clone())
288            .collect();
289
290        if fields.is_empty() {
291            return Err(CursorError::invalid_field(
292                "Cursor pagination requires ORDER BY fields",
293            ));
294        }
295
296        let item = get_item(data).ok_or(CursorError::empty_data())?;
297
298        let values = item.extract_whitelisted_fields(&fields)?;
299
300        if values.len() != fields.len() {
301            return Err(CursorError::invalid_field(
302                "Cursor fields mismatch with sorting params",
303            ));
304        }
305
306        let entries: Vec<CursorEntry> = values
307            .into_iter()
308            .map(|value| CursorEntry { value })
309            .collect();
310
311        Ok(Some(Cursor::new_multi(entries)))
312    }
313
314    /// Generate next cursor from the last item in data
315    pub fn generate_next_cursor<T: CursorSecureExtract>(
316        &self,
317        data: &[T],
318        has_next: bool,
319        sorting_params: &crate::sort::SortingParams,
320    ) -> Result<Option<String>> {
321        let cursor = Self::generate_cursor(data, has_next, sorting_params, |data| data.last())?;
322        match cursor {
323            Some(c) => Ok(Some(T::encode(&c)?)),
324            None => Ok(None),
325        }
326    }
327
328    /// Generate prev cursor from the first item in data
329    pub fn generate_prev_cursor<T: CursorSecureExtract>(
330        &self,
331        data: &[T],
332        has_prev: bool,
333        sorting_params: &crate::sort::SortingParams,
334    ) -> Result<Option<String>> {
335        let cursor = Self::generate_cursor(data, has_prev, sorting_params, |data| data.first())?;
336        match cursor {
337            Some(c) => Ok(Some(T::encode(&c)?)),
338            None => Ok(None),
339        }
340    }
341
342    
343}
344
345// ================================================================================================
346// SECURITY TRAITS
347// ================================================================================================
348
349/// **Security-First Cursor Field Whitelist Trait**
350///
351/// This trait enforces a whitelist-based security model for cursor pagination fields.
352/// Implementors MUST explicitly whitelist each allowed field to prevent field injection attacks.
353pub trait CursorSecureExtract {
354    /// **SECURITY CRITICAL**: Extract values ONLY for explicitly whitelisted cursor fields.
355    ///
356    /// **THIS IS A SECURITY WHITELIST** - Only return values for fields you explicitly allow.
357    /// **ALWAYS** return `Err` for any field not in your whitelist to prevent field injection.
358    ///
359    /// # Security Model
360    ///
361    /// This method acts as the primary defense against field injection attacks via cursor pagination.
362    /// Even if malicious field names are injected through `from_encoded()` or other vectors,
363    /// this whitelist ensures only safe, predefined fields can be accessed.
364    ///
365    /// # Implementation Requirements
366    ///
367    /// - **MUST** use explicit `match field.as_str()` with hardcoded field names
368    /// - **MUST** return `Err` for the default case (`_`)
369    /// - **NEVER** use dynamic field resolution or reflection
370    /// - **ONLY** allow fields that are safe for cursor-based ordering
371    ///
372    /// # Example
373    /// ```rust, ignore
374    /// # use sqlx_data_params::{CursorSecureExtract, CursorValue, CursorError, SqlxError};
375    /// type Result<T> = ::std::result::Result<T, SqlxError>;
376    /// struct User {
377    ///     id: i64,
378    ///     name: String,
379    ///     email: String,
380    ///     password_hash: String, // ← NEVER include sensitive fields!
381    /// }
382    ///
383    /// impl CursorSecureExtract for User {
384    ///     #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
385    ///     fn extract_whitelisted_fields(&self, fields: &[String]) -> Result<Vec<CursorValue>> {
386    ///         let mut values = Vec::with_capacity(fields.len());
387    ///         for field in fields {
388    ///             // 🛡️ SECURITY WHITELIST: Only these fields are allowed
389    ///             match field.as_str() {
390    ///                 "id" => values.push(self.id.into()),           // ✅ Safe: Primary key
391    ///                 "name" => values.push(self.name.clone().into()), // ✅ Safe: Public field
392    ///                 "email" => values.push(self.email.clone().into()), // ✅ Safe: Public field
393    ///                 // password_hash is NOT in whitelist - cannot be accessed via cursor
394    ///                 _ => return Err(CursorError::invalid_field(field.clone()).into()), // 🚫 REJECT: All non-whitelisted fields
395    ///             }
396    ///         }
397    ///         Ok(values)
398    ///     }
399    ///
400    ///     #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))]
401    ///     fn extract_whitelisted_fields(&self, fields: &[String]) -> Result<Vec<CursorValue>> {
402    ///         let mut values = Vec::with_capacity(fields.len());
403    ///         for field in fields {
404    ///             // 🛡️ SECURITY WHITELIST: Only these fields are allowed
405    ///             match field.as_str() {
406    ///                 "id" => values.push(self.id.into()),           // ✅ Safe: Primary key
407    ///                 "name" => values.push(self.name.clone().into()), // ✅ Safe: Public field
408    ///                 "email" => values.push(self.email.clone().into()), // ✅ Safe: Public field
409    ///                 // password_hash is NOT in whitelist - cannot be accessed via cursor
410    ///                 _ => return Err(CursorError::invalid_field(field.clone())), // 🚫 REJECT: All non-whitelisted fields
411    ///             }
412    ///         }
413    ///         Ok(values)
414    ///     }
415    /// }
416    /// ```
417    ///
418    /// # Security Benefits
419    ///
420    /// - **Field Injection Prevention**: Malicious fields from `from_encoded()` are rejected
421    /// - **Data Exposure Control**: Sensitive fields cannot be accessed via cursor pagination
422    /// - **Explicit Security Model**: Developers must consciously choose which fields to expose
423    /// - **Defense in Depth**: Multiple layers protect against various attack vectors
424    #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
425    fn extract_whitelisted_fields(
426        &self,
427        fields: &[String],
428    ) -> Result<Vec<CursorValue>, sqlx_data_integration::Error>;
429
430    #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))]
431    fn extract_whitelisted_fields(&self, fields: &[String]) -> Result<Vec<CursorValue>>;
432
433    /// Encode cursor to string token
434    ///
435    /// Example implementation:
436    /// ```rust,ignore
437    /// fn encode(cursor: &Cursor) -> Result<String, sqlx_data_integration::Error> {
438    ///     use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as BASE64};
439    ///     let json_bytes = serde_json::to_vec(&cursor)
440    ///         .map_err(|e| CursorError::encode_error(format!("JSON serialization failed: {}", e)))?;
441    ///     Ok(BASE64.encode(json_bytes))
442    /// }
443    /// ```
444    #[cfg(feature = "json")]
445    fn encode(cursor: &Cursor) -> Result<String, sqlx_data_integration::Error>;
446
447    /// Encode cursor to string token (JSON feature disabled)
448    #[cfg(not(feature = "json"))]
449    fn encode(_cursor: &Cursor) -> Result<String>;
450
451    /// Decode string token to FilterValue vector
452    ///
453    /// Example implementation:
454    /// ```rust,ignore
455    /// fn decode(encoded: &str) -> Result<Vec<FilterValue>> {
456    ///     use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as BASE64};
457    ///     let bytes = BASE64
458    ///         .decode(encoded)
459    ///         .map_err(|e| CursorError::decode_error(format!("Base64 decode failed: {}", e)))?;
460    ///
461    ///     let cursor: Cursor = serde_json::from_slice(&bytes).map_err(|e| {
462    ///         CursorError::decode_error(format!("JSON deserialization failed: {}", e))
463    ///     })?;
464    ///
465    ///     // Convert CursorValue to FilterValue
466    ///     let filter_values: Vec<FilterValue> = cursor.entries.into_iter().map(|entry| {
467    ///         match entry.value {
468    ///             CursorValue::Int(v) => FilterValue::Int(v),
469    ///             CursorValue::UInt(v) => FilterValue::UInt(v),
470    ///             CursorValue::Float(v) => FilterValue::Float(v),
471    ///             CursorValue::Bool(v) => FilterValue::Bool(v),
472    ///             CursorValue::String(v) => v.into(), // Or Whatever conversion is appropriate
473    ///         }
474    ///     }).collect();
475    ///
476    ///     Ok(filter_values)
477    /// }
478    /// ```
479    #[cfg(feature = "json")]
480    fn decode(encoded: &str) -> Result<Vec<FilterValue>, sqlx_data_integration::Error>;
481
482    /// Decode string token to FilterValue vector (JSON feature disabled)
483    #[cfg(not(feature = "json"))]
484    fn decode(_encoded: &str) -> Result<Vec<FilterValue>>;
485}
486
487// ================================================================================================
488// PARAMS INTEGRATION
489// ================================================================================================
490
491impl IntoParams for CursorParams {
492    fn into_params(self) -> Params {
493        let per_page = 20; // Default value
494        let pagination = crate::pagination::Pagination::Cursor(self);
495        Params {
496            filters: None,
497            search: None,
498            sort_by: None,
499            pagination: Some(pagination),
500            limit: Some(crate::pagination::LimitParam(per_page)),
501            offset: None,
502        }
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `and_field` calls accumulate values and keep the initial direction.
    #[test]
    fn test_cursor_builder_pattern() {
        let params = CursorParams::new(FilterValue::String("alice".into()), CursorDirection::Before)
            .and_field(FilterValue::Int(25))
            .and_field(FilterValue::Float(99.5));

        assert_eq!(params.len(), 3);
        assert_eq!(params.direction, Some(CursorDirection::Before));
    }

    /// Value-carrying params are non-empty and error-free; error params are the opposite.
    #[test]
    fn test_cursor_state_detection() {
        let populated = CursorParams::new(FilterValue::Int(123), CursorDirection::After);
        assert!(!populated.is_empty() && !populated.has_error());

        let failed = CursorParams::with_error(CursorDirection::After, "decode failed");
        assert!(failed.is_empty() && failed.has_error());
    }

    /// `values()` exposes every value that was added.
    #[test]
    fn test_cursor_values() {
        let params = CursorParams::new(FilterValue::Int(123), CursorDirection::After)
            .and_field(FilterValue::String("test".into()));

        assert_eq!(params.len(), 2);
        assert_eq!(params.values().len(), 2);
        assert_eq!(params.direction, Some(CursorDirection::After));
    }

    /// The stored error message is returned verbatim, and is absent on success.
    #[test]
    fn test_error_workflow() {
        let ok = CursorParams::new(FilterValue::Int(123), CursorDirection::After);
        assert!(!ok.has_error());
        assert!(ok.error().is_none());

        let failed = CursorParams::with_error(CursorDirection::Before, "Invalid token");
        assert!(failed.has_error());
        assert_eq!(failed.error(), Some("Invalid token"));
    }
}