// shaperail_runtime/db/pagination.rs
use std::fmt::Write as _;

use serde::{Deserialize, Serialize};
/// A request for a single page of results, in either cursor or offset mode.
#[derive(Debug, Clone)]
pub enum PageRequest {
    /// Cursor (keyset) pagination: an optional opaque `after` cursor and a page size.
    Cursor { after: Option<String>, limit: i64 },
    /// Offset pagination: number of rows to skip and a page size.
    Offset { offset: i64, limit: i64 },
}

impl PageRequest {
    /// Page size used when the client does not specify one.
    pub const DEFAULT_LIMIT: i64 = 25;
    /// Largest page size a client may request.
    pub const MAX_LIMIT: i64 = 100;

    /// Normalises a client-supplied page size.
    ///
    /// `None` falls back to [`Self::DEFAULT_LIMIT`]. The resulting value is then
    /// clamped into `1..=Self::MAX_LIMIT`, so zero or negative requests become `1`
    /// and oversized requests become [`Self::MAX_LIMIT`].
    pub fn clamped_limit(limit: Option<i64>) -> i64 {
        let requested = limit.unwrap_or(Self::DEFAULT_LIMIT);
        requested.clamp(1, Self::MAX_LIMIT)
    }
}
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct CursorPage {
29 pub cursor: Option<String>,
31 pub has_more: bool,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct OffsetPage {
38 pub offset: i64,
40 pub limit: i64,
42 pub total: i64,
44}
45
46impl PageRequest {
47 pub fn apply_cursor_to_sql(
53 &self,
54 sql: &mut String,
55 has_where: bool,
56 param_offset: usize,
57 ) -> usize {
58 match self {
59 PageRequest::Cursor { after, limit } => {
60 let mut offset = param_offset;
61 if after.is_some() {
62 if has_where {
63 sql.push_str(" AND ");
64 } else {
65 sql.push_str(" WHERE ");
66 }
67 sql.push_str(&format!("\"id\" > ${offset}"));
68 offset += 1;
69 }
70 sql.push_str(" ORDER BY \"id\" ASC");
71 sql.push_str(&format!(" LIMIT {}", limit + 1));
73 offset
74 }
75 PageRequest::Offset { offset: off, limit } => {
76 sql.push_str(&format!(" LIMIT {limit} OFFSET {off}"));
77 param_offset
78 }
79 }
80 }
81}
82
83pub fn decode_cursor(cursor: &str) -> Result<String, shaperail_core::ShaperailError> {
85 use std::str;
86 let bytes = base64_decode(cursor).map_err(|_| {
88 shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
89 field: "cursor".to_string(),
90 message: "Invalid cursor format".to_string(),
91 code: "invalid_cursor".to_string(),
92 }])
93 })?;
94 let id = str::from_utf8(&bytes).map_err(|_| {
95 shaperail_core::ShaperailError::Validation(vec![shaperail_core::FieldError {
96 field: "cursor".to_string(),
97 message: "Invalid cursor encoding".to_string(),
98 code: "invalid_cursor".to_string(),
99 }])
100 })?;
101 Ok(id.to_string())
102}
103
104pub fn encode_cursor(id: &str) -> String {
106 base64_encode(id.as_bytes())
107}
108
/// Encodes `data` as standard, padded base64 (alphabet `A-Z a-z 0-9 + /`).
///
/// Each 3-byte group becomes 4 output characters. A trailing group of one or
/// two bytes is completed with `=` padding, so the output length is always
/// `4 * ceil(len / 3)`.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the 6-bit group of `bits` that starts `shift` bits from the right.
    let sextet = |bits: u32, shift: u32| char::from(ALPHABET[((bits >> shift) & 0x3F) as usize]);

    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits; absent bytes stay zero.
        let bits = group
            .iter()
            .zip([16u32, 8, 0])
            .fold(0u32, |acc, (&byte, shift)| acc | (u32::from(byte) << shift));
        encoded.push(sextet(bits, 18));
        encoded.push(sextet(bits, 12));
        encoded.push(if group.len() > 1 { sextet(bits, 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(bits, 0) } else { '=' });
    }
    encoded
}
/// Decodes standard, padded base64 (alphabet `A-Z a-z 0-9 + /`) into bytes.
///
/// The input length must be a multiple of 4. Padding (`=`) is accepted only
/// at the very end of the input, as `xx==` or `xxx=`. Previously `=` anywhere
/// was silently read as a zero sextet, so malformed input decoded to garbage
/// instead of being rejected.
///
/// # Errors
///
/// Returns a static description when:
/// * the length is not a multiple of 4,
/// * padding appears before the final quantum or exceeds two characters, or
/// * any non-padding character lies outside the base64 alphabet.
fn base64_decode(input: &str) -> Result<Vec<u8>, &'static str> {
    /// Maps one alphabet character to its 6-bit value; `=` is handled by the caller.
    fn char_to_val(c: u8) -> Result<u32, &'static str> {
        match c {
            b'A'..=b'Z' => Ok(u32::from(c - b'A')),
            b'a'..=b'z' => Ok(u32::from(c - b'a' + 26)),
            b'0'..=b'9' => Ok(u32::from(c - b'0' + 52)),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err("invalid base64 character"),
        }
    }

    let bytes = input.as_bytes();
    if !bytes.len().is_multiple_of(4) {
        return Err("invalid base64 length");
    }

    let chunk_count = bytes.len() / 4;
    let mut result = Vec::with_capacity(chunk_count * 3);
    for (index, chunk) in bytes.chunks_exact(4).enumerate() {
        // Only the final quantum may carry padding, and at most two `=`.
        let padding = chunk.iter().rev().take_while(|&&c| c == b'=').count();
        let is_last = index + 1 == chunk_count;
        if padding > 2 || (padding > 0 && !is_last) {
            return Err("invalid base64 padding");
        }

        // Pack the 6-bit values big-endian into the low 24 bits. Any `=` left in
        // the non-padding prefix is rejected by `char_to_val`.
        let mut triple = 0u32;
        for (position, &c) in chunk[..4 - padding].iter().enumerate() {
            triple |= char_to_val(c)? << (18 - 6 * position);
        }

        // Bytes 1..4 of the big-endian u32 hold the 24 decoded bits; each `=`
        // drops one trailing output byte.
        result.extend_from_slice(&triple.to_be_bytes()[1..4 - padding]);
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies `page` to a copy of `base` and returns the resulting SQL together
    /// with the placeholder index reported back by `apply_cursor_to_sql`.
    fn render(
        page: &PageRequest,
        base: &str,
        has_where: bool,
        first_param: usize,
    ) -> (String, usize) {
        let mut sql = base.to_owned();
        let next_param = page.apply_cursor_to_sql(&mut sql, has_where, first_param);
        (sql, next_param)
    }

    #[test]
    fn clamped_limit_default() {
        assert_eq!(PageRequest::clamped_limit(None), 25);
    }

    #[test]
    fn clamped_limit_within_range() {
        for requested in [10, 50] {
            assert_eq!(PageRequest::clamped_limit(Some(requested)), requested);
        }
    }

    #[test]
    fn clamped_limit_too_high() {
        assert_eq!(PageRequest::clamped_limit(Some(500)), 100);
    }

    #[test]
    fn clamped_limit_too_low() {
        for requested in [0, -5] {
            assert_eq!(PageRequest::clamped_limit(Some(requested)), 1);
        }
    }

    #[test]
    fn cursor_encode_decode_roundtrip() {
        let id = "550e8400-e29b-41d4-a716-446655440000";
        let decoded = decode_cursor(&encode_cursor(id)).unwrap();
        assert_eq!(decoded, id);
    }

    #[test]
    fn invalid_cursor_returns_error() {
        assert!(decode_cursor("!!invalid!!").is_err());
    }

    #[test]
    fn cursor_pagination_sql_no_cursor() {
        let page = PageRequest::Cursor {
            after: None,
            limit: 25,
        };
        let (sql, next_param) = render(&page, "SELECT * FROM users", false, 1);

        assert_eq!(sql, "SELECT * FROM users ORDER BY \"id\" ASC LIMIT 26");
        assert_eq!(next_param, 1);
    }

    #[test]
    fn cursor_pagination_sql_with_cursor() {
        let page = PageRequest::Cursor {
            after: Some("some-uuid".to_string()),
            limit: 10,
        };
        let (sql, next_param) = render(&page, "SELECT * FROM users", false, 1);

        assert_eq!(
            sql,
            "SELECT * FROM users WHERE \"id\" > $1 ORDER BY \"id\" ASC LIMIT 11"
        );
        assert_eq!(next_param, 2);
    }

    #[test]
    fn cursor_pagination_with_existing_where() {
        let page = PageRequest::Cursor {
            after: Some("some-uuid".to_string()),
            limit: 10,
        };
        let (sql, next_param) =
            render(&page, "SELECT * FROM users WHERE \"role\" = $1", true, 2);

        assert!(sql.contains("AND \"id\" > $2"));
        assert!(sql.contains("LIMIT 11"));
        assert_eq!(next_param, 3);
    }

    #[test]
    fn offset_pagination_sql() {
        let page = PageRequest::Offset {
            offset: 20,
            limit: 10,
        };
        let (sql, next_param) = render(&page, "SELECT * FROM users", false, 1);

        assert_eq!(sql, "SELECT * FROM users LIMIT 10 OFFSET 20");
        assert_eq!(next_param, 1);
    }
}