Skip to main content

supabase_client_query/
sql.rs

1use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
2use serde_json::Value as JsonValue;
3use uuid::Uuid;
4
/// Type-erased SQL parameter for dynamic query building.
///
/// Each variant wraps one Rust value meant to be sent as a bind parameter
/// (referenced from SQL as `$N`). Values can be produced through
/// [`IntoSqlParam`]; `Option::None` converts to [`SqlParam::Null`].
/// NOTE(review): the concrete PostgreSQL type each variant is bound as is
/// decided by the executor, not here — confirm there.
#[derive(Debug, Clone)]
pub enum SqlParam {
    /// SQL `NULL`.
    Null,
    /// Boolean value.
    Bool(bool),
    /// 16-bit signed integer.
    I16(i16),
    /// 32-bit signed integer.
    I32(i32),
    /// 64-bit signed integer.
    I64(i64),
    /// 32-bit floating-point number.
    F32(f32),
    /// 64-bit floating-point number.
    F64(f64),
    /// Text value (owned UTF-8 `String`).
    Text(String),
    /// UUID value.
    Uuid(Uuid),
    /// Date and time without time-zone information.
    Timestamp(NaiveDateTime),
    /// Date and time in UTC.
    TimestampTz(chrono::DateTime<chrono::Utc>),
    /// Calendar date without a time component.
    Date(NaiveDate),
    /// Time of day without a date or time zone.
    Time(NaiveTime),
    /// Arbitrary JSON value.
    Json(JsonValue),
    /// Raw bytes.
    ByteArray(Vec<u8>),
    /// Array of text values.
    TextArray(Vec<String>),
    /// Array of 32-bit integers.
    I32Array(Vec<i32>),
    /// Array of 64-bit integers.
    I64Array(Vec<i64>),
}
27
/// Trait for converting Rust types into `SqlParam`.
///
/// Implemented in this module for every scalar, date/time, JSON and vector
/// type that has a matching [`SqlParam`] variant, for `&str` (copied into an
/// owned `String`), for `SqlParam` itself (identity), and for `Option<T>`
/// (where `None` becomes [`SqlParam::Null`]).
pub trait IntoSqlParam {
    /// Consume `self` and wrap it in the corresponding [`SqlParam`] variant.
    fn into_sql_param(self) -> SqlParam;
}
32
33// Implementations for all common types
34
35impl IntoSqlParam for SqlParam {
36    fn into_sql_param(self) -> SqlParam {
37        self
38    }
39}
40
41impl IntoSqlParam for bool {
42    fn into_sql_param(self) -> SqlParam {
43        SqlParam::Bool(self)
44    }
45}
46
47impl IntoSqlParam for i16 {
48    fn into_sql_param(self) -> SqlParam {
49        SqlParam::I16(self)
50    }
51}
52
53impl IntoSqlParam for i32 {
54    fn into_sql_param(self) -> SqlParam {
55        SqlParam::I32(self)
56    }
57}
58
59impl IntoSqlParam for i64 {
60    fn into_sql_param(self) -> SqlParam {
61        SqlParam::I64(self)
62    }
63}
64
65impl IntoSqlParam for f32 {
66    fn into_sql_param(self) -> SqlParam {
67        SqlParam::F32(self)
68    }
69}
70
71impl IntoSqlParam for f64 {
72    fn into_sql_param(self) -> SqlParam {
73        SqlParam::F64(self)
74    }
75}
76
77impl IntoSqlParam for String {
78    fn into_sql_param(self) -> SqlParam {
79        SqlParam::Text(self)
80    }
81}
82
83impl IntoSqlParam for &str {
84    fn into_sql_param(self) -> SqlParam {
85        SqlParam::Text(self.to_string())
86    }
87}
88
89impl IntoSqlParam for Uuid {
90    fn into_sql_param(self) -> SqlParam {
91        SqlParam::Uuid(self)
92    }
93}
94
95impl IntoSqlParam for NaiveDateTime {
96    fn into_sql_param(self) -> SqlParam {
97        SqlParam::Timestamp(self)
98    }
99}
100
101impl IntoSqlParam for chrono::DateTime<chrono::Utc> {
102    fn into_sql_param(self) -> SqlParam {
103        SqlParam::TimestampTz(self)
104    }
105}
106
107impl IntoSqlParam for NaiveDate {
108    fn into_sql_param(self) -> SqlParam {
109        SqlParam::Date(self)
110    }
111}
112
113impl IntoSqlParam for NaiveTime {
114    fn into_sql_param(self) -> SqlParam {
115        SqlParam::Time(self)
116    }
117}
118
119impl IntoSqlParam for JsonValue {
120    fn into_sql_param(self) -> SqlParam {
121        SqlParam::Json(self)
122    }
123}
124
125impl IntoSqlParam for Vec<u8> {
126    fn into_sql_param(self) -> SqlParam {
127        SqlParam::ByteArray(self)
128    }
129}
130
131impl IntoSqlParam for Vec<String> {
132    fn into_sql_param(self) -> SqlParam {
133        SqlParam::TextArray(self)
134    }
135}
136
137impl IntoSqlParam for Vec<i32> {
138    fn into_sql_param(self) -> SqlParam {
139        SqlParam::I32Array(self)
140    }
141}
142
143impl IntoSqlParam for Vec<i64> {
144    fn into_sql_param(self) -> SqlParam {
145        SqlParam::I64Array(self)
146    }
147}
148
149impl<T: IntoSqlParam> IntoSqlParam for Option<T> {
150    fn into_sql_param(self) -> SqlParam {
151        match self {
152            Some(v) => v.into_sql_param(),
153            None => SqlParam::Null,
154        }
155    }
156}
157
158/// Store for collecting parameters during query building.
159#[derive(Debug, Clone, Default)]
160pub struct ParamStore {
161    params: Vec<SqlParam>,
162}
163
164impl ParamStore {
165    pub fn new() -> Self {
166        Self { params: Vec::new() }
167    }
168
169    /// Push a parameter and return its 1-based index (for `$N` placeholders).
170    pub fn push(&mut self, param: SqlParam) -> usize {
171        self.params.push(param);
172        self.params.len()
173    }
174
175    /// Push a value that implements IntoSqlParam.
176    pub fn push_value(&mut self, value: impl IntoSqlParam) -> usize {
177        self.push(value.into_sql_param())
178    }
179
180    /// Get a parameter by 0-based index.
181    pub fn get(&self, index: usize) -> Option<&SqlParam> {
182        self.params.get(index)
183    }
184
185    /// Get all parameters.
186    pub fn params(&self) -> &[SqlParam] {
187        &self.params
188    }
189
190    /// Consume and return all parameters.
191    pub fn into_params(self) -> Vec<SqlParam> {
192        self.params
193    }
194
195    /// Number of parameters stored.
196    pub fn len(&self) -> usize {
197        self.params.len()
198    }
199
200    pub fn is_empty(&self) -> bool {
201        self.params.is_empty()
202    }
203}
204
205// --- Filter types ---
206
/// A single filter condition in a WHERE clause.
///
/// Conditions reference bind parameters by index instead of embedding values;
/// presumably these are the 1-based `$N` numbers returned by
/// [`ParamStore::push`] — NOTE(review): confirm against the SQL renderer.
#[derive(Debug, Clone)]
pub enum FilterCondition {
    /// `column <op> $N` (e.g. `"name" = $1`).
    Comparison {
        column: String,
        operator: FilterOperator,
        param_index: usize,
    },
    /// `column IS NULL` / `IS NOT NULL` / `IS TRUE` / `IS FALSE`; takes no bind parameter.
    Is {
        column: String,
        value: IsValue,
    },
    /// `column IN ($1, $2, ...)`, one placeholder per element.
    ///
    /// NOTE(review): an empty `param_indices` rendered literally would be
    /// `IN ()`, which PostgreSQL rejects — confirm the renderer handles it.
    In {
        column: String,
        param_indices: Vec<usize>,
    },
    /// `column LIKE $N` / `column ILIKE $N`.
    Pattern {
        column: String,
        operator: PatternOperator,
        param_index: usize,
    },
    /// Full-text search: `column @@ <tsquery function>(config, $N)`.
    ///
    /// The tsquery function is selected by `search_type` (see
    /// [`TextSearchType::function_name`]); `config` is presumably the optional
    /// text-search configuration name (e.g. `english`).
    TextSearch {
        column: String,
        query_param_index: usize,
        config: Option<String>,
        search_type: TextSearchType,
    },
    /// Array/range operator comparison (`@>`, `<@`, `&&`, `>>`, `&>`, `<<`, `&<`, `-|-`).
    ArrayRange {
        column: String,
        operator: ArrayRangeOperator,
        param_index: usize,
    },
    /// `NOT (condition)`.
    Not(Box<FilterCondition>),
    /// `(condition OR condition OR ...)`.
    Or(Vec<FilterCondition>),
    /// `(condition AND condition AND ...)` - used inside or_filter to group terms.
    And(Vec<FilterCondition>),
    /// Raw SQL fragment (escape hatch).
    ///
    /// NOTE(review): nothing in this type validates or parameterizes the text,
    /// so it must never be built from untrusted input.
    Raw(String),
    /// Several `column = $N` conditions combined with AND.
    Match {
        conditions: Vec<(String, usize)>,
    },
}
258
/// Binary comparison operator used by [`FilterCondition::Comparison`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOperator {
    Eq,
    Neq,
    Gt,
    Gte,
    Lt,
    Lte,
}

impl FilterOperator {
    /// SQL spelling of the operator.
    ///
    /// `Neq` renders as `!=`, which PostgreSQL accepts as an alias of `<>`.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            FilterOperator::Eq => "=",
            FilterOperator::Neq => "!=",
            FilterOperator::Gt => ">",
            FilterOperator::Gte => ">=",
            FilterOperator::Lt => "<",
            FilterOperator::Lte => "<=",
        }
    }
}
281
/// Pattern-matching operator used by [`FilterCondition::Pattern`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PatternOperator {
    /// Case-sensitive match.
    Like,
    /// Case-insensitive match.
    ILike,
}

impl PatternOperator {
    /// SQL keyword for this operator.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            PatternOperator::ILike => "ILIKE",
            PatternOperator::Like => "LIKE",
        }
    }
}
296
/// Right-hand side of an `IS` test used by [`FilterCondition::Is`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsValue {
    Null,
    NotNull,
    True,
    False,
}

impl IsValue {
    /// Full `IS ...` predicate text (the column is written separately).
    pub fn as_sql(&self) -> &'static str {
        match *self {
            IsValue::Null => "IS NULL",
            IsValue::NotNull => "IS NOT NULL",
            IsValue::True => "IS TRUE",
            IsValue::False => "IS FALSE",
        }
    }
}
315
/// Which PostgreSQL tsquery constructor parses the full-text search input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TextSearchType {
    Plain,
    Phrase,
    Websearch,
}

impl TextSearchType {
    /// Name of the SQL function that turns the search text into a `tsquery`.
    pub fn function_name(&self) -> &'static str {
        let name = match *self {
            TextSearchType::Plain => "plainto_tsquery",
            TextSearchType::Phrase => "phraseto_tsquery",
            TextSearchType::Websearch => "websearch_to_tsquery",
        };
        name
    }
}
332
/// Array / range operator used by [`FilterCondition::ArrayRange`].
///
/// The `Range*` variants map to PostgreSQL's positional range operators; the
/// names follow the "greater/less than" framing of the client API rather than
/// the operators' literal meaning (documented per variant below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArrayRangeOperator {
    /// `@>`: left value contains the right value.
    Contains,
    /// `<@`: left value is contained by the right value.
    ContainedBy,
    /// `&&`: the two values overlap (have elements/points in common).
    Overlaps,
    /// `>>`: range is strictly right of the other.
    RangeGt,
    /// `&>`: range does not extend to the left of the other.
    RangeGte,
    /// `<<`: range is strictly left of the other.
    RangeLt,
    /// `&<`: range does not extend to the right of the other.
    RangeLte,
    /// `-|-`: ranges are adjacent.
    RangeAdjacent,
}

impl ArrayRangeOperator {
    /// SQL operator token for this variant.
    pub fn as_sql(&self) -> &'static str {
        use ArrayRangeOperator as Op;
        match *self {
            Op::Contains => "@>",
            Op::ContainedBy => "<@",
            Op::Overlaps => "&&",
            Op::RangeGt => ">>",
            Op::RangeGte => "&>",
            Op::RangeLt => "<<",
            Op::RangeLte => "&<",
            Op::RangeAdjacent => "-|-",
        }
    }
}
359
360// --- Order / Modifier types ---
361
/// One `ORDER BY` term: a column, its sort direction, and optional NULLS placement.
#[derive(Debug, Clone)]
pub struct OrderClause {
    /// Column to sort by.
    pub column: String,
    /// `ASC` or `DESC`.
    pub direction: OrderDirection,
    /// `NULLS FIRST` / `NULLS LAST`; `None` presumably omits the clause so the
    /// database default applies — NOTE(review): confirm in the renderer.
    pub nulls: Option<NullsPosition>,
}
368
/// Sort direction of an [`OrderClause`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderDirection {
    Ascending,
    Descending,
}

impl OrderDirection {
    /// SQL keyword (`ASC` / `DESC`).
    pub fn as_sql(&self) -> &'static str {
        match *self {
            OrderDirection::Descending => "DESC",
            OrderDirection::Ascending => "ASC",
        }
    }
}
383
/// Where NULLs sort within an [`OrderClause`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NullsPosition {
    First,
    Last,
}

impl NullsPosition {
    /// SQL modifier (`NULLS FIRST` / `NULLS LAST`).
    pub fn as_sql(&self) -> &'static str {
        match *self {
            NullsPosition::Last => "NULLS LAST",
            NullsPosition::First => "NULLS FIRST",
        }
    }
}
398
/// Count mode for responses: whether, and how precisely, the rows matched by
/// the query are counted.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CountOption {
    /// No count requested.
    None,
    /// Exact count via COUNT(*).
    Exact,
    /// Planned count from query planner (fast, approximate).
    Planned,
    /// Estimated count using statistics (fast, approximate).
    Estimated,
}
411
412// --- SQL Parts ---
413
/// The type of SQL operation.
///
/// `Upsert` is an insert governed by the conflict settings on [`SqlParts`]
/// (`conflict_columns`, `conflict_constraint`, `ignore_duplicates`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlOperation {
    /// `SELECT`.
    Select,
    /// `INSERT`.
    Insert,
    /// `UPDATE`.
    Update,
    /// `DELETE`.
    Delete,
    /// `INSERT ... ON CONFLICT ...`.
    Upsert,
}
423
/// Collects all the components of a SQL query being built.
///
/// Fields that refer to bind parameters hold indices (presumably the `$N`
/// numbers issued by [`ParamStore`]) rather than the values themselves.
#[derive(Debug, Clone)]
pub struct SqlParts {
    /// Statement kind to generate.
    pub operation: SqlOperation,
    /// Default schema of `table`; `schema_override` takes precedence when set.
    pub schema: String,
    /// Target table name, stored unquoted.
    pub table: String,
    /// Columns to select (None = *)
    pub select_columns: Option<String>,
    /// Filter conditions (WHERE)
    pub filters: Vec<FilterCondition>,
    /// ORDER BY clauses, in priority order
    pub orders: Vec<OrderClause>,
    /// LIMIT
    pub limit: Option<i64>,
    /// OFFSET (from range)
    pub offset: Option<i64>,
    /// Whether to return a single row (enforced at execution)
    pub single: bool,
    /// Whether to return zero or one row
    pub maybe_single: bool,
    /// Count option
    pub count: CountOption,
    /// Insert/Update column-value pairs: Vec<(column, param_index)>
    pub set_clauses: Vec<(String, usize)>,
    /// For insert_many/upsert_many: Vec of rows, each is Vec<(column, param_index)>
    pub many_rows: Vec<Vec<(String, usize)>>,
    /// RETURNING columns (None = don't return, Some("*") = all)
    pub returning: Option<String>,
    /// ON CONFLICT columns (for upsert)
    pub conflict_columns: Vec<String>,
    /// ON CONFLICT constraint name (alternative to columns)
    pub conflict_constraint: Option<String>,
    /// When true, upsert generates ON CONFLICT DO NOTHING instead of DO UPDATE
    pub ignore_duplicates: bool,
    /// Schema override for per-query schema qualification; replaces `schema` when set
    pub schema_override: Option<String>,
    /// EXPLAIN options (only for SELECT)
    pub explain: Option<ExplainOptions>,
    /// Head mode: SELECT count(*) only, no rows
    pub head: bool,
}
465
/// Options for the `EXPLAIN` modifier.
#[derive(Debug, Clone)]
pub struct ExplainOptions {
    /// Request `ANALYZE` (PostgreSQL then actually executes the statement).
    pub analyze: bool,
    /// Request `VERBOSE` plan output.
    pub verbose: bool,
    /// Output format of the plan.
    pub format: ExplainFormat,
}

impl Default for ExplainOptions {
    /// `analyze` on, `verbose` off, JSON output.
    fn default() -> Self {
        let format = ExplainFormat::Json;
        Self {
            format,
            verbose: false,
            analyze: true,
        }
    }
}

/// Output format for EXPLAIN.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExplainFormat {
    Text,
    Json,
    Xml,
    Yaml,
}

impl ExplainFormat {
    /// Value for the `FORMAT` option of `EXPLAIN`.
    pub fn as_sql(&self) -> &'static str {
        match *self {
            ExplainFormat::Text => "TEXT",
            ExplainFormat::Json => "JSON",
            ExplainFormat::Xml => "XML",
            ExplainFormat::Yaml => "YAML",
        }
    }
}
503
504impl SqlParts {
505    pub fn new(operation: SqlOperation, schema: impl Into<String>, table: impl Into<String>) -> Self {
506        Self {
507            operation,
508            schema: schema.into(),
509            table: table.into(),
510            select_columns: None,
511            filters: Vec::new(),
512            orders: Vec::new(),
513            limit: None,
514            offset: None,
515            single: false,
516            maybe_single: false,
517            count: CountOption::None,
518            set_clauses: Vec::new(),
519            many_rows: Vec::new(),
520            returning: None,
521            conflict_columns: Vec::new(),
522            conflict_constraint: None,
523            ignore_duplicates: false,
524            schema_override: None,
525            explain: None,
526            head: false,
527        }
528    }
529
530    /// Get the fully-qualified table name, using schema_override if set.
531    pub fn qualified_table(&self) -> String {
532        let schema = self.schema_override.as_deref().unwrap_or(&self.schema);
533        format!("\"{}\".\"{}\"", schema, self.table)
534    }
535}
536
537/// Validate that a column name is safe (no SQL injection).
538pub fn validate_column_name(name: &str) -> Result<(), supabase_client_core::SupabaseError> {
539    if name.is_empty() {
540        return Err(supabase_client_core::SupabaseError::query_builder(
541            "Column name cannot be empty",
542        ));
543    }
544    if name.contains('"') || name.contains(';') || name.contains("--") {
545        return Err(supabase_client_core::SupabaseError::query_builder(format!(
546            "Invalid column name: {name:?} (contains prohibited characters)"
547        )));
548    }
549    Ok(())
550}
551
552/// Validate a table or schema name.
553pub fn validate_identifier(name: &str, kind: &str) -> Result<(), supabase_client_core::SupabaseError> {
554    if name.is_empty() {
555        return Err(supabase_client_core::SupabaseError::query_builder(format!(
556            "{kind} name cannot be empty"
557        )));
558    }
559    if name.contains('"') || name.contains(';') || name.contains("--") {
560        return Err(supabase_client_core::SupabaseError::query_builder(format!(
561            "Invalid {kind} name: {name:?} (contains prohibited characters)"
562        )));
563    }
564    Ok(())
565}