Skip to main content

sqlx_paginated/paginated_query_as/builders/query_builders/
query_builder.rs

1use crate::paginated_query_as::internal::{ColumnProtection, QueryDialect};
2use crate::paginated_query_as::models::{QueryFilterCondition, QueryFilterOperator};
3use crate::QueryParams;
4use chrono::{DateTime, Utc};
5use serde::Serialize;
6use sqlx::{Arguments, Database, Encode, Type};
7use std::marker::PhantomData;
8
9pub struct QueryBuilder<'q, T, DB: Database> {
10    pub conditions: Vec<String>,
11    pub arguments: DB::Arguments<'q>,
12    pub(crate) valid_columns: Vec<String>,
13    pub(crate) protection: Option<ColumnProtection>,
14    pub(crate) protection_enabled: bool,
15    pub(crate) dialect: Box<dyn QueryDialect>,
16    pub(crate) _phantom: PhantomData<&'q T>,
17}
18
19impl<'q, T, DB> QueryBuilder<'q, T, DB>
20where
21    T: Default + Serialize,
22    DB: Database,
23    String: for<'a> Encode<'a, DB> + Type<DB>,
24{
25    /// Checks if a column exists in the list of valid columns for T struct.
26    ///
27    /// # Arguments
28    ///
29    /// * `column` - The name of the column to check
30    ///
31    /// # Returns
32    ///
33    /// Returns `true` if the column exists in the valid columns list, `false` otherwise.
34    pub(crate) fn has_column(&self, column: &str) -> bool {
35        self.valid_columns.contains(&column.to_string())
36    }
37
38    fn is_column_safe(&self, column: &str) -> bool {
39        let column_exists = self.has_column(column);
40
41        if !self.protection_enabled {
42            return column_exists;
43        }
44
45        match &self.protection {
46            Some(protection) => column_exists && protection.is_safe(column),
47            None => column_exists,
48        }
49    }
50
51    /// Adds search functionality to the query by creating LIKE conditions for specified columns.
52    ///
53    /// # Arguments
54    ///
55    /// * `params` - Query parameters containing search text and columns to search in
56    ///
57    /// # Details
58    ///
59    /// - Only searches in columns that are both specified and considered safe
60    /// - Creates case-insensitive LIKE conditions with wildcards
61    /// - Multiple search columns are combined with OR operators
62    /// - Empty search text or no valid columns results in no conditions being added
63    ///
64    /// # Returns
65    ///
66    /// Returns self for method chaining
67    ///
68    /// # Example
69    ///
70    /// ```rust
71    /// use sqlx::Postgres;
72    /// use serde::{Serialize};
73    /// use sqlx_paginated::{QueryBuilder, QueryParamsBuilder};
74    ///
75    /// #[derive(Serialize, Default)]
76    /// struct UserExample {
77    ///     name: String
78    /// }
79    ///
80    /// let initial_params = QueryParamsBuilder::<UserExample>::new()
81    ///         .with_search("john", vec!["name", "email"])
82    ///         .build();
83    /// let query_builder = QueryBuilder::<UserExample, Postgres>::new()
84    ///     .with_search(&initial_params)
85    ///     .build();
86    /// ```
87    pub fn with_search(mut self, params: &QueryParams<T>) -> Self {
88        if let Some(search) = &params.search.search {
89            if let Some(columns) = &params.search.search_columns {
90                let valid_search_columns: Vec<&String> = columns
91                    .iter()
92                    .filter(|column| self.is_column_safe(column))
93                    .collect();
94
95                if !valid_search_columns.is_empty() && !search.trim().is_empty() {
96                    let pattern = format!("%{}%", search);
97                    let use_lower = search.is_ascii();
98
99                    let search_conditions: Vec<String> = valid_search_columns
100                        .iter()
101                        .enumerate()
102                        .map(|(idx, column)| {
103                            let table_column = self.dialect.quote_identifier(column);
104                            let placeholder =
105                                self.dialect.placeholder(self.arguments.len() + idx + 1);
106                            if use_lower {
107                                format!("LOWER({}) LIKE LOWER({})", table_column, placeholder)
108                            } else {
109                                format!("{} LIKE {}", table_column, placeholder)
110                            }
111                        })
112                        .collect();
113
114                    if !search_conditions.is_empty() {
115                        self.conditions
116                            .push(format!("({})", search_conditions.join(" OR ")));
117                        for _ in 0..valid_search_columns.len() {
118                            self.arguments.add(pattern.clone()).unwrap_or_default();
119                        }
120                    }
121                }
122            }
123        }
124        self
125    }
126
127    /// Adds filter conditions to the query with support for various operators.
128    ///
129    /// # Arguments
130    ///
131    /// * `params` - Query parameters containing filters with operators
132    ///
133    /// # Details
134    ///
135    /// - Only applies filters for columns that exist and are considered safe
136    /// - Supports multiple operators: =, !=, >, >=, <, <=, IN, NOT IN, IS NULL, IS NOT NULL, LIKE, NOT LIKE
137    /// - Automatically handles type casting based on the database dialect
138    /// - Skips invalid columns with a warning when tracing is enabled
139    /// - For IN/NOT IN operators, comma-separated values are split into multiple parameters
140    ///
141    /// # Returns
142    ///
143    /// Returns self for method chaining
144    ///
145    /// # Example
146    ///
147    /// ```rust
148    /// use sqlx::Postgres;
149    /// use serde::{Serialize};
150    /// use sqlx_paginated::{QueryBuilder, QueryParamsBuilder, QueryFilterOperator};
151    ///
152    /// #[derive(Serialize, Default)]
153    /// struct Product {
154    ///     name: String,
155    ///     price: f64,
156    ///     stock: i32,
157    /// }
158    ///
159    /// let initial_params = QueryParamsBuilder::<Product>::new()
160    ///         .with_filter_operator("price", QueryFilterOperator::GreaterThan, "10.00")
161    ///         .with_filter_operator("stock", QueryFilterOperator::LessOrEqual, "100")
162    ///         .build();
163    ///
164    /// let query_builder = QueryBuilder::<Product, Postgres>::new()
165    ///     .with_filters(&initial_params)
166    ///     .build();
167    /// ```
168    pub fn with_filters(mut self, params: &'q QueryParams<T>) -> Self {
169        for (key, condition) in &params.filters {
170            if self.is_column_safe(key) {
171                self = self.apply_filter_condition(key, condition);
172            } else {
173                #[cfg(feature = "tracing")]
174                tracing::warn!(column = %key, "Skipping invalid filter column");
175            }
176        }
177        self
178    }
179
180    /// Applies a single filter condition to the query.
181    ///
182    /// This is a helper method that handles the SQL generation for different operators.
183    fn apply_filter_condition(mut self, column: &str, condition: &'q QueryFilterCondition) -> Self {
184        let table_column = self.dialect.quote_identifier(column);
185
186        match &condition.operator {
187            QueryFilterOperator::IsNull => {
188                self.conditions.push(format!("{} IS NULL", table_column));
189            }
190            QueryFilterOperator::IsNotNull => {
191                self.conditions
192                    .push(format!("{} IS NOT NULL", table_column));
193            }
194            QueryFilterOperator::In | QueryFilterOperator::NotIn => {
195                if let Some(_value) = &condition.value {
196                    let values = condition.split_values();
197                    if !values.is_empty() {
198                        let mut placeholders = Vec::new();
199                        for val in values {
200                            let next_argument = self.arguments.len() + 1;
201                            let placeholder = self.dialect.placeholder(next_argument);
202                            let type_cast = self.dialect.type_cast(&val);
203                            placeholders.push(format!("{}{}", placeholder, type_cast));
204                            self.arguments.add(val).unwrap_or_default();
205                        }
206
207                        let operator = condition.operator.to_sql();
208                        self.conditions.push(format!(
209                            "{} {} ({})",
210                            table_column,
211                            operator,
212                            placeholders.join(", ")
213                        ));
214                    }
215                }
216            }
217            QueryFilterOperator::Like | QueryFilterOperator::NotLike => {
218                if let Some(value) = &condition.value {
219                    let next_argument = self.arguments.len() + 1;
220                    let placeholder = self.dialect.placeholder(next_argument);
221                    let operator = condition.operator.to_sql();
222
223                    self.conditions.push(format!(
224                        "LOWER({}) {} LOWER({})",
225                        table_column, operator, placeholder
226                    ));
227                    self.arguments.add(value).unwrap_or_default();
228                }
229            }
230            _ => {
231                // Handle all comparison operators: =, !=, >, >=, <, <=
232                if let Some(value) = &condition.value {
233                    let next_argument = self.arguments.len() + 1;
234                    let placeholder = self.dialect.placeholder(next_argument);
235                    let type_cast = self.dialect.type_cast(value);
236                    let operator = condition.operator.to_sql();
237
238                    self.conditions.push(format!(
239                        "{} {} {}{}",
240                        table_column, operator, placeholder, type_cast
241                    ));
242                    self.arguments.add(value).unwrap_or_default();
243                }
244            }
245        }
246
247        self
248    }
249
250    /// Adds date range conditions to the query for a specified date column.
251    ///
252    /// # Arguments
253    ///
254    /// * `params` - Query parameters containing date range information
255    ///
256    /// # Type Parameters
257    ///
258    /// Requires `DateTime<Utc>` to be encodable for the target database
259    ///
260    /// # Details
261    ///
262    /// - Adds >= condition for date_after if specified
263    /// - Adds <= condition for date_before if specified
264    /// - Only applies to columns that exist and are considered safe
265    /// - Skips invalid date columns with a warning when tracing is enabled
266    ///
267    /// # Returns
268    ///
269    /// Returns self for method chaining
270    ///
271    /// # Example
272    ///
273    /// ```rust
274    /// use sqlx::Postgres;
275    /// use serde::{Serialize};
276    /// use chrono::{DateTime};
277    /// use sqlx_paginated::{QueryBuilder, QueryParamsBuilder, QueryParams};
278    ///
279    /// #[derive(Serialize, Default)]
280    /// struct UserExample {
281    ///     name: String
282    /// }
283    ///
284    /// let initial_params = QueryParamsBuilder::<UserExample>::new()
285    ///         .with_date_range(None, Some(DateTime::parse_from_rfc3339("2024-12-31T23:59:59Z").unwrap().into()), Some("deleted_at"))
286    ///         .build();
287    /// let query_builder = QueryBuilder::<UserExample, Postgres>::new()
288    ///     .with_date_range(&initial_params)
289    ///     .build();
290    /// ```
291    pub fn with_date_range(mut self, params: &'q QueryParams<T>) -> Self
292    where
293        DateTime<Utc>: for<'a> Encode<'a, DB> + Type<DB>,
294    {
295        if let Some(date_column) = &params.date_range.date_column {
296            if self.is_column_safe(date_column) {
297                if let Some(after) = params.date_range.date_after {
298                    let next_argument = self.arguments.len() + 1;
299                    let table_column = self.dialect.quote_identifier(date_column);
300                    let placeholder = self.dialect.placeholder(next_argument);
301                    self.conditions
302                        .push(format!("{} >= {}", table_column, placeholder));
303                    self.arguments.add(after).unwrap_or_default();
304                }
305
306                if let Some(before) = params.date_range.date_before {
307                    let next_argument = self.arguments.len() + 1;
308                    let table_column = self.dialect.quote_identifier(date_column);
309                    let placeholder = self.dialect.placeholder(next_argument);
310                    self.conditions
311                        .push(format!("{} <= {}", table_column, placeholder));
312                    self.arguments.add(before).unwrap_or_default();
313                }
314            } else {
315                #[cfg(feature = "tracing")]
316                tracing::warn!(column = %date_column, "Skipping invalid date column");
317            }
318        }
319
320        self
321    }
322
323    /// Adds a custom condition for a specific column with a provided operator and value.
324    ///
325    /// # Arguments
326    ///
327    /// * `column` - The column name to apply the condition to
328    /// * `condition` - The operator or condition to use (e.g., ">", "LIKE", etc.)
329    /// * `value` - The value to compare against
330    ///
331    /// # Details
332    ///
333    /// - Only applies to columns that exist and are considered safe
334    /// - Automatically handles parameter binding
335    /// - Skips invalid columns with a warning when tracing is enabled
336    ///
337    /// # Returns
338    ///
339    /// Returns self for method chaining
340    ///
341    /// # Example
342    ///
343    /// ```rust
344    /// use sqlx::Postgres;
345    /// use serde::{Serialize};
346    /// use sqlx_paginated::{QueryBuilder};
347    ///
348    /// #[derive(Serialize, Default)]
349    /// struct UserExample {
350    ///     name: String
351    /// }
352    ///
353    /// let query_builder = QueryBuilder::<UserExample, Postgres>::new()
354    ///     .with_condition("age", ">", "18".to_string())
355    ///     .build();
356    /// ```
357    pub fn with_condition(
358        mut self,
359        column: &str,
360        condition: impl Into<String>,
361        value: String,
362    ) -> Self {
363        if self.is_column_safe(column) {
364            let next_argument = self.arguments.len() + 1;
365            let table_column = self.dialect.quote_identifier(column);
366            let placeholder = self.dialect.placeholder(next_argument);
367            self.conditions.push(format!(
368                "{} {} {}",
369                table_column,
370                condition.into(),
371                placeholder
372            ));
373            let _ = self.arguments.add(value);
374        } else {
375            #[cfg(feature = "tracing")]
376            tracing::warn!(column = %column, "Skipping invalid condition column");
377        }
378        self
379    }
380
381    /// Adds a raw SQL condition to the query without any safety checks.
382    ///
383    /// # Arguments
384    ///
385    /// * `condition` - Raw SQL condition to add to the query
386    ///
387    /// # Safety
388    ///
389    /// This method bypasses column safety checks. Use with caution to prevent SQL injection.
390    ///
391    /// # Returns
392    ///
393    /// Returns self for method chaining
394    ///
395    /// # Example
396    ///
397    /// ```rust
398    /// use sqlx::Postgres;
399    /// use serde::{Serialize};
400    /// use sqlx_paginated::{QueryBuilder};
401    ///
402    /// #[derive(Serialize, Default)]
403    /// struct UserExample {
404    ///     name: String
405    /// }
406    ///
407    /// let query_builder = QueryBuilder::<UserExample, Postgres>::new()
408    ///     .with_raw_condition("status != 'deleted'")
409    ///     .build();
410    /// ```
411    pub fn with_raw_condition(mut self, condition: impl Into<String>) -> Self {
412        self.conditions.push(condition.into());
413        self
414    }
415
416    /// Allows adding multiple conditions using a closure.
417    ///
418    /// # Arguments
419    ///
420    /// * `f` - Closure that takes a mutable reference to the QueryBuilder
421    ///
422    /// # Details
423    ///
424    /// Useful for grouping multiple conditions that are logically related
425    ///
426    /// # Returns
427    ///
428    /// Returns self for method chaining
429    ///
430    /// # Example
431    ///
432    /// ```rust
433    /// use sqlx::Postgres;
434    /// use serde::{Serialize};
435    /// use sqlx_paginated::{QueryBuilder};
436    ///
437    /// #[derive(Serialize, Default)]
438    /// struct UserExample {
439    ///     name: String
440    /// }
441    /// let query_builder = QueryBuilder::<UserExample, Postgres>::new()
442    ///     .with_combined_conditions(|builder| {
443    ///         builder.conditions.push("status = 'active'".to_string());
444    ///         builder.conditions.push("age >= 18".to_string());
445    ///     })
446    ///     .build();
447    /// ```
448    pub fn with_combined_conditions<F>(mut self, f: F) -> Self
449    where
450        F: FnOnce(&mut QueryBuilder<T, DB>),
451    {
452        f(&mut self);
453        self
454    }
455
456    /// Disables column protection for this query builder instance.
457    ///
458    /// # Safety
459    ///
460    /// This removes all column safety checks. Use with caution as it may expose
461    /// the application to SQL injection if used with untrusted input.
462    ///
463    /// # Returns
464    ///
465    /// Returns self for method chaining
466    ///
467    /// # Example
468    ///
469    /// ```rust
470    /// use sqlx::Postgres;
471    /// use serde::{Serialize};
472    /// use sqlx_paginated::{QueryBuilder};
473    ///
474    /// #[derive(Serialize, Default)]
475    /// struct UserExample {
476    ///     name: String
477    /// }
478    ///
479    /// let query_builder = QueryBuilder::<UserExample, Postgres>::new()
480    ///     .disable_protection()
481    ///     .with_raw_condition("custom_column = 'value'")
482    ///     .build();
483    /// ```
484    pub fn disable_protection(mut self) -> Self {
485        self.protection_enabled = false;
486        self
487    }
488
489    /// Builds the final query conditions and arguments.
490    ///
491    /// # Returns
492    ///
493    /// Returns a tuple containing:
494    /// - Vec<String>: List of SQL conditions
495    /// - DB::Arguments: Database-specific arguments for parameter binding
496    ///
497    /// # Example
498    ///
499    /// ```rust
500    /// use sqlx::Postgres;
501    /// use serde::{Serialize};
502    /// use sqlx_paginated::{QueryBuilder, QueryParamsBuilder};
503    ///
504    /// #[derive(Serialize, Default)]
505    /// struct UserExample {
506    ///     name: String
507    /// }
508    ///
509    /// let initial_params = QueryParamsBuilder::<UserExample>::new()
510    ///         .with_search("john", vec!["name", "email"])
511    ///         .build();
512    /// let (conditions, arguments) = QueryBuilder::<UserExample, Postgres>::new()
513    ///     .with_search(&initial_params)
514    ///     .build();
515    /// ```
516    pub fn build(self) -> (Vec<String>, DB::Arguments<'q>) {
517        (self.conditions, self.arguments)
518    }
519}