1use crate::FilterValue;
2use crate::{IntoParams, Params};
3
4#[cfg(feature = "json")]
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug, PartialEq)]
12#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
13pub struct CursorEntry {
14 pub value: CursorValue,
15}
16
17#[derive(Clone, Debug, PartialEq)]
19#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
20#[derive(Default)]
21pub struct Cursor {
22 #[cfg_attr(feature = "json", serde(skip_serializing_if = "Vec::is_empty"))]
23 pub entries: Vec<CursorEntry>,
24 #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
25 pub version: Option<u8>,
26 #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
27 pub fingerprint: Option<u64>,
28}
29
30#[derive(Clone, Debug, PartialEq)]
32#[derive(Default)]
33pub struct CursorParams {
34 pub direction: Option<CursorDirection>,
36 pub values: Vec<FilterValue>,
38 pub error: Option<String>,
40}
41
/// Which side of the cursor position a page is taken from.
///
/// The enum is fieldless, so it is `Copy` and can be compared, hashed and
/// passed by value freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CursorDirection {
    /// Page positioned after the cursor.
    After,
    /// Page positioned before the cursor.
    Before,
}
47
/// Scalar value stored in a cursor entry.
///
/// With the `json` feature the enum is (de)serialized untagged, i.e. as the
/// bare inner value. Untagged deserialization tries variants in declaration
/// order, so the variant order below is significant and must not be changed.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(
    feature = "json",
    derive(Serialize, Deserialize),
    serde(untagged)
)]
pub enum CursorValue {
    /// Signed 64-bit integer.
    Int(i64),
    /// Unsigned 64-bit integer.
    UInt(u64),
    /// 64-bit floating point number.
    Float(f64),
    /// Boolean.
    Bool(bool),
    /// Owned string.
    String(String),
}
58
59impl From<i64> for CursorValue {
60 fn from(value: i64) -> Self {
61 CursorValue::Int(value)
62 }
63}
64
65impl From<u64> for CursorValue {
66 fn from(value: u64) -> Self {
67 CursorValue::UInt(value)
68 }
69}
70
71impl From<f64> for CursorValue {
72 fn from(value: f64) -> Self {
73 CursorValue::Float(value)
74 }
75}
76
77impl From<bool> for CursorValue {
78 fn from(value: bool) -> Self {
79 CursorValue::Bool(value)
80 }
81}
82
83impl From<String> for CursorValue {
84 fn from(value: String) -> Self {
85 CursorValue::String(value)
86 }
87}
88
89impl From<&str> for CursorValue {
90 fn from(value: &str) -> Self {
91 CursorValue::String(value.to_string())
92 }
93}
94
95impl From<i32> for CursorValue {
96 fn from(value: i32) -> Self {
97 CursorValue::Int(value as i64)
98 }
99}
100
101impl From<u32> for CursorValue {
102 fn from(value: u32) -> Self {
103 CursorValue::UInt(value as u64)
104 }
105}
106
107impl From<f32> for CursorValue {
108 fn from(value: f32) -> Self {
109 CursorValue::Float(value as f64)
110 }
111}
112
113pub type Result<T, E = CursorError> = ::std::result::Result<T, E>;
114
/// Alias for the error type of the SQL integration crate.
///
/// Only defined when at least one of the `sqlite`, `postgres` or `mysql`
/// features is enabled.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
pub type SqlxError = sqlx_data_integration::Error;
121
122#[derive(Debug, thiserror::Error)]
123#[non_exhaustive]
124enum CursorErrorKind {
125 #[error(transparent)]
126 #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
127 Sqlx(#[from] SqlxError),
128
129 #[error("Field '{0}' not allowed for cursor pagination")]
130 InvalidField(String),
131
132 #[error("Data is empty")]
133 EmptyData,
134
135 #[error("Encoding cursor failed: {0}")]
136 EncodeError(String),
137
138 #[error("Decoding cursor failed: {0}")]
139 DecodeError(String),
140}
141
142#[derive(Debug)]
143pub struct CursorError(CursorErrorKind);
144
145impl CursorError {
146 pub fn invalid_field(field: impl Into<String>) -> Self {
148 Self(CursorErrorKind::InvalidField(field.into()))
149 }
150
151 pub fn empty_data() -> Self {
152 Self(CursorErrorKind::EmptyData)
153 }
154
155 pub fn encode_error(msg: impl Into<String>) -> Self {
156 Self(CursorErrorKind::EncodeError(msg.into()))
157 }
158
159 pub fn decode_error(msg: impl Into<String>) -> Self {
160 Self(CursorErrorKind::DecodeError(msg.into()))
161 }
162}
163
/// Wraps an error from the SQL integration layer so that `?` can propagate
/// it out of functions returning `CursorError`.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
impl From<SqlxError> for CursorError {
    fn from(source: SqlxError) -> Self {
        CursorError(CursorErrorKind::Sqlx(source))
    }
}
171
172impl std::fmt::Display for CursorError {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 self.0.fmt(f)
175 }
176}
177
178impl std::error::Error for CursorError {
179 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
180 self.0.source()
181 }
182}
183
/// Converts a cursor error back into the SQL integration error type.
///
/// A wrapped SQL error is returned as-is; every other kind is boxed into the
/// `Decode` variant.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
impl From<CursorError> for sqlx_data_integration::Error {
    fn from(err: CursorError) -> Self {
        let CursorError(kind) = err;
        match kind {
            // `SqlxError` is an alias of the target type, so no conversion
            // is needed here.
            CursorErrorKind::Sqlx(inner) => inner,
            other => sqlx_data_integration::Error::Decode(other.into()),
        }
    }
}
193
194impl Cursor {
201 pub fn new() -> Self {
202 Self::default()
203 }
204
205 pub fn new_multi(entries: Vec<CursorEntry>) -> Self {
206 Self { entries, version: None, fingerprint: None }
207 }
208
209 pub fn and_field(mut self, value: impl Into<CursorValue>) -> Self {
210 self.entries.push(CursorEntry {
211 value: value.into(),
212 });
213 self
214 }
215}
216
217impl CursorParams {
218 pub fn new(value: FilterValue, direction: CursorDirection) -> Self {
219 Self {
220 values: vec![value],
221 direction: Some(direction),
222 error: None,
223 }
224 }
225
226 pub fn from_values(values: Vec<FilterValue>, direction: CursorDirection) -> Self {
227 Self {
228 values,
229 direction: Some(direction),
230 error: None,
231 }
232 }
233
234 pub fn with_error(direction: CursorDirection, error: impl Into<String>) -> Self {
235 Self {
236 values: vec![],
237 direction: Some(direction),
238 error: Some(error.into()),
239 }
240 }
241
242 pub fn and_field(mut self, value: FilterValue) -> Self {
243 self.values.push(value);
244 self
245 }
246
247 pub fn values(&self) -> &[FilterValue] {
249 &self.values
250 }
251
252 pub fn is_empty(&self) -> bool {
254 self.values.is_empty()
255 }
256
257 pub fn len(&self) -> usize {
259 self.values.len()
260 }
261
262 pub fn has_error(&self) -> bool {
264 self.error.is_some()
265 }
266
267 pub fn error(&self) -> Option<&str> {
269 self.error.as_deref()
270 }
271
272 fn generate_cursor<T: CursorSecureExtract>(
274 data: &[T],
275 has_more: bool,
276 sorting_params: &crate::sort::SortingParams,
277 get_item: impl FnOnce(&[T]) -> Option<&T>,
278 ) -> Result<Option<Cursor>> {
279 if !has_more || data.is_empty() {
280 return Ok(None);
281 }
282
283 let fields: Vec<String> = sorting_params
285 .sorts()
286 .iter()
287 .map(|s| s.field.clone())
288 .collect();
289
290 if fields.is_empty() {
291 return Err(CursorError::invalid_field(
292 "Cursor pagination requires ORDER BY fields",
293 ));
294 }
295
296 let item = get_item(data).ok_or(CursorError::empty_data())?;
297
298 let values = item.extract_whitelisted_fields(&fields)?;
299
300 if values.len() != fields.len() {
301 return Err(CursorError::invalid_field(
302 "Cursor fields mismatch with sorting params",
303 ));
304 }
305
306 let entries: Vec<CursorEntry> = values
307 .into_iter()
308 .map(|value| CursorEntry { value })
309 .collect();
310
311 Ok(Some(Cursor::new_multi(entries)))
312 }
313
314 pub fn generate_next_cursor<T: CursorSecureExtract>(
316 &self,
317 data: &[T],
318 has_next: bool,
319 sorting_params: &crate::sort::SortingParams,
320 ) -> Result<Option<String>> {
321 let cursor = Self::generate_cursor(data, has_next, sorting_params, |data| data.last())?;
322 match cursor {
323 Some(c) => Ok(Some(T::encode(&c)?)),
324 None => Ok(None),
325 }
326 }
327
328 pub fn generate_prev_cursor<T: CursorSecureExtract>(
330 &self,
331 data: &[T],
332 has_prev: bool,
333 sorting_params: &crate::sort::SortingParams,
334 ) -> Result<Option<String>> {
335 let cursor = Self::generate_cursor(data, has_prev, sorting_params, |data| data.first())?;
336 match cursor {
337 Some(c) => Ok(Some(T::encode(&c)?)),
338 None => Ok(None),
339 }
340 }
341
342
343}
344
345pub trait CursorSecureExtract {
354 #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
425 fn extract_whitelisted_fields(
426 &self,
427 fields: &[String],
428 ) -> Result<Vec<CursorValue>, sqlx_data_integration::Error>;
429
430 #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))]
431 fn extract_whitelisted_fields(&self, fields: &[String]) -> Result<Vec<CursorValue>>;
432
433 #[cfg(feature = "json")]
445 fn encode(cursor: &Cursor) -> Result<String, sqlx_data_integration::Error>;
446
447 #[cfg(not(feature = "json"))]
449 fn encode(_cursor: &Cursor) -> Result<String>;
450
451 #[cfg(feature = "json")]
480 fn decode(encoded: &str) -> Result<Vec<FilterValue>, sqlx_data_integration::Error>;
481
482 #[cfg(not(feature = "json"))]
484 fn decode(_encoded: &str) -> Result<Vec<FilterValue>>;
485}
486
487impl IntoParams for CursorParams {
492 fn into_params(self) -> Params {
493 let per_page = 20; let pagination = crate::pagination::Pagination::Cursor(self);
495 Params {
496 filters: None,
497 search: None,
498 sort_by: None,
499 pagination: Some(pagination),
500 limit: Some(crate::pagination::LimitParam(per_page)),
501 offset: None,
502 }
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `and_field` calls accumulate values and keep the direction.
    #[test]
    fn test_cursor_builder_pattern() {
        let first = FilterValue::String("alice".into());
        let params = CursorParams::new(first, CursorDirection::Before)
            .and_field(FilterValue::Int(25))
            .and_field(FilterValue::Float(99.5));

        assert_eq!(params.len(), 3);
        assert_eq!(params.direction, Some(CursorDirection::Before));
    }

    /// `is_empty` / `has_error` distinguish data-carrying from error-carrying
    /// parameters.
    #[test]
    fn test_cursor_state_detection() {
        let with_data = CursorParams::new(FilterValue::Int(123), CursorDirection::After);
        assert!(!with_data.is_empty());
        assert!(!with_data.has_error());

        let with_error = CursorParams::with_error(CursorDirection::After, "decode failed");
        assert!(with_error.is_empty());
        assert!(with_error.has_error());
    }

    /// `len` and `values` agree on the number of stored values.
    #[test]
    fn test_cursor_values() {
        let params = CursorParams::new(FilterValue::Int(123), CursorDirection::After)
            .and_field(FilterValue::String("test".into()));

        assert_eq!(params.len(), 2);
        assert_eq!(params.values().len(), 2);
        assert_eq!(params.direction, Some(CursorDirection::After));
    }

    /// `error` returns the attached message, or `None` when there is none.
    #[test]
    fn test_error_workflow() {
        let ok = CursorParams::new(FilterValue::Int(123), CursorDirection::After);
        assert!(!ok.has_error());
        assert_eq!(ok.error(), None);

        let failed = CursorParams::with_error(CursorDirection::Before, "Invalid token");
        assert!(failed.has_error());
        assert_eq!(failed.error(), Some("Invalid token"));
    }
}