Skip to main content

sqlx_data_params/
sort.rs

1use crate::{IntoParams, Params};
2
/// Direction in which a single sort key is ordered.
///
/// The default is [`SortDirection::Asc`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum SortDirection {
    /// Ascending order (the default).
    #[default]
    Asc,
    /// Descending order.
    Desc,
}
10
/// Requested placement of NULL values for a sort key.
///
/// The default is [`NullOrdering::Default`], meaning no explicit placement is requested.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum NullOrdering {
    /// Request `NULLS FIRST`.
    First,
    /// Request `NULLS LAST`.
    Last,
    /// No explicit placement requested (the default).
    #[default]
    Default,
}
19
/// Internal marker recording how a `Sort`'s field name was supplied, so that
/// unchecked names can later be validated against an allowed-columns list.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
enum SortSafety {
    /// Built via `Sort::new`, `Sort::asc` or `Sort::desc` (the default).
    ///
    /// NOTE(review): nothing in this module checks the field name of a safe sort;
    /// any compile-time validation presumably happens in the caller (e.g. a macro)
    /// — confirm before passing user input to these constructors.
    #[default]
    Safe,
    /// Built via `Sort::new_unsafe`, `Sort::asc_unsafe` or `Sort::desc_unsafe`;
    /// the field name must be validated against a whitelist at runtime.
    Unsafe,
}
27
28impl SortDirection {
29    /// Returns true if this is ASC direction
30    pub fn is_asc(self) -> bool {
31        matches!(self, Self::Asc)
32    }
33
34    /// Returns true if this is DESC direction
35    pub fn is_desc(self) -> bool {
36        matches!(self, Self::Desc)
37    }
38
39    pub fn flip(self) -> Self {
40        match self {
41            Self::Asc => Self::Desc,
42            Self::Desc => Self::Asc,
43        }
44    }
45}
46
47
48impl NullOrdering {
49    /// Returns Some(true) for First, Some(false) for Last, None for Default
50    pub fn as_bool_option(self) -> Option<bool> {
51        match self {
52            Self::First => Some(true),
53            Self::Last => Some(false),
54            Self::Default => None,
55        }
56    }
57
58    /// Returns true if nulls should come first
59    pub fn is_first(self) -> bool {
60        matches!(self, Self::First)
61    }
62
63    /// Returns true if nulls should come last
64    pub fn is_last(self) -> bool {
65        matches!(self, Self::Last)
66    }
67
68    /// Returns true if using default null ordering
69    pub fn is_default(self) -> bool {
70        matches!(self, Self::Default)
71    }
72}
73
74
75#[derive(Debug, Clone)]
76pub struct Sort {
77    pub field: String,
78    pub direction: SortDirection,
79    pub nulls: NullOrdering,
80    safety_type: SortSafety,
81}
82
83impl Sort {
84    pub fn new(field: impl Into<String>, direction: SortDirection) -> Self {
85        Self {
86            field: field.into(),
87            direction,
88            nulls: NullOrdering::default(),
89            safety_type: SortSafety::Safe,
90        }
91    }
92
93    pub fn new_unsafe(field: impl Into<String>, direction: SortDirection) -> Self {
94        Self {
95            field: field.into(),
96            direction,
97            nulls: NullOrdering::default(),
98            safety_type: SortSafety::Unsafe,
99        }
100    }
101
102    /// Returns true if this sort was added via unsafe methods
103    pub fn is_unsafe(&self) -> bool {
104        matches!(self.safety_type, SortSafety::Unsafe)
105    }
106
107    pub fn asc(field: impl Into<String>) -> Self {
108        Self::new(field, SortDirection::Asc)
109    }
110
111    pub fn desc(field: impl Into<String>) -> Self {
112        Self::new(field, SortDirection::Desc)
113    }
114
115    pub fn asc_unsafe(field: impl Into<String>) -> Self {
116        Self::new_unsafe(field, SortDirection::Asc)
117    }
118
119    pub fn desc_unsafe(field: impl Into<String>) -> Self {
120        Self::new_unsafe(field, SortDirection::Desc)
121    }
122
123    pub fn nulls_first(mut self) -> Self {
124        self.nulls = NullOrdering::First;
125        self
126    }
127
128    pub fn nulls_last(mut self) -> Self {
129        self.nulls = NullOrdering::Last;
130        self
131    }
132
133    pub fn nulls_default(mut self) -> Self {
134        self.nulls = NullOrdering::Default;
135        self
136    }
137
138    /// Returns true if this is ASC direction
139    pub fn is_asc(&self) -> bool {
140        self.direction.is_asc()
141    }
142
143    /// Returns true if this is DESC direction
144    pub fn is_desc(&self) -> bool {
145        self.direction.is_desc()
146    }
147
148    /// Returns nulls ordering as Option<bool>
149    pub fn nulls_as_bool(&self) -> Option<bool> {
150        self.nulls.as_bool_option()
151    }
152}
153
154#[derive(Debug, Clone, Default)]
155pub struct SortingParams {
156    sorts: Vec<Sort>,
157    unsafe_fields: Vec<String>,
158    allowed_columns: Option<&'static [&'static str]>,
159}
160
161impl SortingParams {
162    pub fn new() -> Self {
163        Self {
164            sorts: Vec::new(),
165            unsafe_fields: Vec::new(),
166            allowed_columns: None,
167        }
168    }
169
170    /// Get readonly access to sorts
171    pub fn sorts(&self) -> &[Sort] {
172        &self.sorts
173    }
174
175    /// Check if there are any unsafe fields that need validation
176    pub fn has_unsafe_fields(&self) -> bool {
177        !self.unsafe_fields.is_empty()
178    }
179
180    /// Set allowed columns for unsafe field validation
181    pub fn with_allowed_columns(mut self, columns: &'static [&'static str]) -> Self {
182        self.allowed_columns = Some(columns);
183        self
184    }
185
186    /// Validate all unsafe fields against the whitelist
187    pub fn validate_fields(&self) -> Result<(), String> {
188        if !self.has_unsafe_fields() {
189            return Ok(());
190        }
191
192        let allowed = match self.allowed_columns {
193            Some(cols) => cols,
194            None => {
195                return Err("Unsafe fields present but no allowed_columns specified".to_string());
196            }
197        };
198
199        // SECURITY NOTE:
200        // This uses slice::contains, which performs an exact, case-sensitive match.
201        // It does NOT perform substring matching.
202        // Do NOT replace with str::contains, starts_with, regex, etc. (SQL injection risk)
203        for field in &self.unsafe_fields {
204            if !allowed.contains(&field.as_str()) {
205                return Err(format!("Field '{}' is not in allowed columns list", field));
206            }
207        }
208
209        Ok(())
210    }
211
212    pub fn push(mut self, sort: Sort) -> Self {
213        // Track unsafe fields
214        if sort.is_unsafe() {
215            self.unsafe_fields.push(sort.field.clone());
216        }
217        self.sorts.push(sort);
218        self
219    }
220
221    pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
222        self.sorts.push(Sort::new(field, direction));
223        self
224    }
225
226    pub fn asc(self, field: impl Into<String>) -> Self {
227        self.sort_by(field, SortDirection::Asc)
228    }
229
230    pub fn desc(self, field: impl Into<String>) -> Self {
231        self.sort_by(field, SortDirection::Desc)
232    }
233
234    pub fn is_empty(&self) -> bool {
235        self.sorts.is_empty()
236    }
237
238    /// Apply NULLS FIRST to the last added sort
239    pub fn apply_nulls_first(mut self) -> Self {
240        if let Some(last_sort) = self.sorts.last_mut() {
241            last_sort.nulls = NullOrdering::First;
242        }
243        self
244    }
245
246    /// Apply NULLS LAST to the last added sort
247    pub fn apply_nulls_last(mut self) -> Self {
248        if let Some(last_sort) = self.sorts.last_mut() {
249            last_sort.nulls = NullOrdering::Last;
250        }
251        self
252    }
253
254    /// Apply default null ordering to the last added sort
255    pub fn apply_nulls_default(mut self) -> Self {
256        if let Some(last_sort) = self.sorts.last_mut() {
257            last_sort.nulls = NullOrdering::Default;
258        }
259        self
260    }
261
262    /// Combine with another SortingParams, extending the sorts list
263    pub fn extend_with(mut self, other: SortingParams) -> Self {
264        self.sorts.extend(other.sorts);
265        self.unsafe_fields.extend(other.unsafe_fields);
266        // Preserve allowed_columns from self if present, otherwise use other's
267        if self.allowed_columns.is_none() {
268            self.allowed_columns = other.allowed_columns;
269        }
270        self
271    }
272}
273
274impl IntoParams for SortingParams {
275    fn into_params(self) -> Params {
276        Params {
277            filters: None,
278            search: None,
279            sort_by: Some(self),
280            pagination: None,
281            limit: None,
282            offset: None,
283        }
284    }
285}