sqlint/ast/
table.rs

1use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition, Join, JoinData};
2use crate::{
3    ast::{Expression, Row, Select, Values},
4    error::{Error, ErrorKind},
5};
6use std::borrow::Cow;
7
/// Something that can be given an alias and then referred to by that alias
/// elsewhere in a query.
pub trait Aliasable<'a> {
    /// The value produced once the alias has been attached.
    type Target;

    /// Alias table for usage elsewhere in the query.
    fn alias<T: Into<Cow<'a, str>>>(self, alias: T) -> Self::Target;
}
17
18#[derive(Clone, Debug, PartialEq)]
19/// Either an identifier or a nested query.
20pub enum TableType<'a> {
21    Table(Cow<'a, str>),
22    JoinedTable(Box<(Cow<'a, str>, Vec<Join<'a>>)>),
23    Query(Box<Select<'a>>),
24    Values(Values<'a>),
25}
26
27/// A table definition
28#[derive(Clone, Debug)]
29pub struct Table<'a> {
30    pub typ: TableType<'a>,
31    pub alias: Option<Cow<'a, str>>,
32    pub database: Option<Cow<'a, str>>,
33    pub(crate) index_definitions: Vec<IndexDefinition<'a>>,
34}
35
36impl<'a> PartialEq for Table<'a> {
37    fn eq(&self, other: &Table) -> bool {
38        self.typ == other.typ && self.database == other.database
39    }
40}
41
42impl<'a> Table<'a> {
43    /// Define in which database the table is located
44    pub fn database<T>(mut self, database: T) -> Self
45    where
46        T: Into<Cow<'a, str>>,
47    {
48        self.database = Some(database.into());
49        self
50    }
51
52    /// A qualified asterisk to this table
53    pub fn asterisk(self) -> Expression<'a> {
54        Expression { kind: ExpressionKind::Asterisk(Some(Box::new(self))), alias: None }
55    }
56
57    /// Add unique index definition.
58    pub fn add_unique_index(mut self, i: impl Into<IndexDefinition<'a>>) -> Self {
59        let definition = i.into();
60        self.index_definitions.push(definition.set_table(self.clone()));
61        self
62    }
63
64    /// Conditions for Microsoft T-SQL MERGE using the table metadata.
65    ///
66    /// - Find the unique indices from the table that matches the inserted columns
67    /// - Create a join from the virtual table with the uniques
68    /// - Combine joins with `OR`
69    /// - If the the index is a compound with other columns, combine them with `AND`
70    /// - If the column is not provided and index exists, try inserting a default value.
71    /// - Otherwise the function will return an error.
72    pub(crate) fn join_conditions(&self, inserted_columns: &[Column<'a>]) -> crate::Result<ConditionTree<'a>> {
73        let mut result = ConditionTree::NegativeCondition;
74
75        let join_cond = |column: &Column<'a>| {
76            let cond = if !inserted_columns.contains(column) {
77                match column.default.clone() {
78                    Some(DefaultValue::Provided(val)) => Some(column.clone().equals(val).into()),
79                    Some(DefaultValue::Generated) => None,
80                    None => {
81                        let kind =
82                            ErrorKind::conversion("A unique column missing from insert and table has no default.");
83
84                        return Err(Error::builder(kind).build());
85                    }
86                }
87            } else {
88                let dual_col = column.clone().table("dual");
89                Some(dual_col.equals(column.clone()).into())
90            };
91
92            Ok::<Option<ConditionTree>, Error>(cond)
93        };
94
95        for index in self.index_definitions.iter() {
96            match index {
97                IndexDefinition::Single(column) => {
98                    if let Some(right_cond) = join_cond(column)? {
99                        match result {
100                            ConditionTree::NegativeCondition => result = right_cond,
101                            left_cond => result = left_cond.or(right_cond),
102                        }
103                    }
104                }
105                IndexDefinition::Compound(cols) => {
106                    let mut sub_result = ConditionTree::NoCondition;
107
108                    for right in cols.iter() {
109                        let right_cond = join_cond(right)?.unwrap_or(ConditionTree::NegativeCondition);
110
111                        match sub_result {
112                            ConditionTree::NoCondition => sub_result = right_cond,
113                            left_cond => sub_result = left_cond.and(right_cond),
114                        }
115                    }
116
117                    match result {
118                        ConditionTree::NegativeCondition => result = sub_result,
119                        left_cond => result = left_cond.or(sub_result),
120                    }
121                }
122            }
123        }
124
125        Ok(result)
126    }
127
128    /// Adds a `LEFT JOIN` clause to the query, specifically for that table.
129    /// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
130    ///
131    /// ```rust
132    /// # use sqlint::{ast::*, visitor::{Visitor, Sqlite}};
133    /// # fn main() -> Result<(), sqlint::error::Error> {
134    /// let join = "posts".alias("p").on(("p", "visible").equals(true));
135    /// let joined_table = Table::from("users").left_join(join);
136    /// let query = Select::from_table(joined_table).and_from("comments");
137    /// let (sql, params) = Sqlite::build(query)?;
138    ///
139    /// assert_eq!(
140    ///     "SELECT `users`.*, `comments`.* FROM \
141    ///     `users` LEFT JOIN `posts` AS `p` ON `p`.`visible` = ?, \
142    ///     `comments`",
143    ///     sql
144    /// );
145    ///
146    /// assert_eq!(
147    ///     vec![
148    ///         Value::from(true),
149    ///     ],
150    ///     params
151    /// );
152    /// # Ok(())
153    /// # }
154    /// ```
155    pub fn left_join<J>(mut self, join: J) -> Self
156    where
157        J: Into<JoinData<'a>>,
158    {
159        match self.typ {
160            TableType::Table(table_name) => {
161                self.typ = TableType::JoinedTable(Box::new((table_name, vec![Join::Left(join.into())])))
162            }
163            TableType::JoinedTable(ref mut jt) => jt.1.push(Join::Left(join.into())),
164            TableType::Query(_) => {
165                panic!("You cannot left_join on a table of type Query")
166            }
167            TableType::Values(_) => {
168                panic!("You cannot left_join on a table of type Values")
169            }
170        }
171
172        self
173    }
174
175    /// Adds an `INNER JOIN` clause to the query, specifically for that table.
176    /// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
177    ///
178    /// ```rust
179    /// # use sqlint::{ast::*, visitor::{Visitor, Sqlite}};
180    /// # fn main() -> Result<(), sqlint::error::Error> {
181    /// let join = "posts".alias("p").on(("p", "visible").equals(true));
182    /// let joined_table = Table::from("users").inner_join(join);
183    /// let query = Select::from_table(joined_table).and_from("comments");
184    /// let (sql, params) = Sqlite::build(query)?;
185    ///
186    /// assert_eq!(
187    ///     "SELECT `users`.*, `comments`.* FROM \
188    ///     `users` INNER JOIN `posts` AS `p` ON `p`.`visible` = ?, \
189    ///     `comments`",
190    ///     sql
191    /// );
192    ///
193    /// assert_eq!(
194    ///     vec![
195    ///         Value::from(true),
196    ///     ],
197    ///     params
198    /// );
199    /// # Ok(())
200    /// # }
201    /// ```
202    pub fn inner_join<J>(mut self, join: J) -> Self
203    where
204        J: Into<JoinData<'a>>,
205    {
206        match self.typ {
207            TableType::Table(table_name) => {
208                self.typ = TableType::JoinedTable(Box::new((table_name, vec![Join::Inner(join.into())])))
209            }
210            TableType::JoinedTable(ref mut jt) => jt.1.push(Join::Inner(join.into())),
211            TableType::Query(_) => {
212                panic!("You cannot inner_join on a table of type Query")
213            }
214            TableType::Values(_) => {
215                panic!("You cannot inner_join on a table of type Values")
216            }
217        }
218
219        self
220    }
221
222    /// Adds a `RIGHT JOIN` clause to the query, specifically for that table.
223    /// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
224    ///
225    /// ```rust
226    /// # use sqlint::{ast::*, visitor::{Visitor, Sqlite}};
227    /// # fn main() -> Result<(), sqlint::error::Error> {
228    /// let join = "posts".alias("p").on(("p", "visible").equals(true));
229    /// let joined_table = Table::from("users").right_join(join);
230    /// let query = Select::from_table(joined_table).and_from("comments");
231    /// let (sql, params) = Sqlite::build(query)?;
232    ///
233    /// assert_eq!(
234    ///     "SELECT `users`.*, `comments`.* FROM \
235    ///     `users` RIGHT JOIN `posts` AS `p` ON `p`.`visible` = ?, \
236    ///     `comments`",
237    ///     sql
238    /// );
239    ///
240    /// assert_eq!(
241    ///     vec![
242    ///         Value::from(true),
243    ///     ],
244    ///     params
245    /// );
246    /// # Ok(())
247    /// # }
248    /// ```
249    pub fn right_join<J>(mut self, join: J) -> Self
250    where
251        J: Into<JoinData<'a>>,
252    {
253        match self.typ {
254            TableType::Table(table_name) => {
255                self.typ = TableType::JoinedTable(Box::new((table_name, vec![Join::Right(join.into())])))
256            }
257            TableType::JoinedTable(ref mut jt) => jt.1.push(Join::Right(join.into())),
258            TableType::Query(_) => {
259                panic!("You cannot right_join on a table of type Query")
260            }
261            TableType::Values(_) => {
262                panic!("You cannot right_join on a table of type Values")
263            }
264        }
265
266        self
267    }
268
269    /// Adds a `FULL JOIN` clause to the query, specifically for that table.
270    /// Useful to positionally add a JOIN clause in case you are selecting from multiple tables.
271    ///
272    /// ```rust
273    /// # use sqlint::{ast::*, visitor::{Visitor, Sqlite}};
274    /// # fn main() -> Result<(), sqlint::error::Error> {
275    /// let join = "posts".alias("p").on(("p", "visible").equals(true));
276    /// let joined_table = Table::from("users").full_join(join);
277    /// let query = Select::from_table(joined_table).and_from("comments");
278    /// let (sql, params) = Sqlite::build(query)?;
279    ///
280    /// assert_eq!(
281    ///     "SELECT `users`.*, `comments`.* FROM \
282    ///     `users` FULL JOIN `posts` AS `p` ON `p`.`visible` = ?, \
283    ///     `comments`",
284    ///     sql
285    /// );
286    ///
287    /// assert_eq!(
288    ///     vec![
289    ///         Value::from(true),
290    ///     ],
291    ///     params
292    /// );
293    /// # Ok(())
294    /// # }
295    /// ```
296    pub fn full_join<J>(mut self, join: J) -> Self
297    where
298        J: Into<JoinData<'a>>,
299    {
300        match self.typ {
301            TableType::Table(table_name) => {
302                self.typ = TableType::JoinedTable(Box::new((table_name, vec![Join::Full(join.into())])))
303            }
304            TableType::JoinedTable(ref mut jt) => jt.1.push(Join::Full(join.into())),
305            TableType::Query(_) => {
306                panic!("You cannot full_join on a table of type Query")
307            }
308            TableType::Values(_) => {
309                panic!("You cannot full_join on a table of type Values")
310            }
311        }
312
313        self
314    }
315}
316
317impl<'a> From<&'a str> for Table<'a> {
318    fn from(s: &'a str) -> Table<'a> {
319        Table { typ: TableType::Table(s.into()), alias: None, database: None, index_definitions: Vec::new() }
320    }
321}
322
323impl<'a> From<&'a String> for Table<'a> {
324    fn from(s: &'a String) -> Table<'a> {
325        Table { typ: TableType::Table(s.into()), alias: None, database: None, index_definitions: Vec::new() }
326    }
327}
328
329impl<'a> From<(&'a str, &'a str)> for Table<'a> {
330    fn from(s: (&'a str, &'a str)) -> Table<'a> {
331        let table: Table<'a> = s.1.into();
332        table.database(s.0)
333    }
334}
335
336impl<'a> From<(&'a str, &'a String)> for Table<'a> {
337    fn from(s: (&'a str, &'a String)) -> Table<'a> {
338        let table: Table<'a> = s.1.into();
339        table.database(s.0)
340    }
341}
342
343impl<'a> From<(&'a String, &'a str)> for Table<'a> {
344    fn from(s: (&'a String, &'a str)) -> Table<'a> {
345        let table: Table<'a> = s.1.into();
346        table.database(s.0)
347    }
348}
349
350impl<'a> From<(&'a String, &'a String)> for Table<'a> {
351    fn from(s: (&'a String, &'a String)) -> Table<'a> {
352        let table: Table<'a> = s.1.into();
353        table.database(s.0)
354    }
355}
356
357impl<'a> From<String> for Table<'a> {
358    fn from(s: String) -> Self {
359        Table { typ: TableType::Table(s.into()), alias: None, database: None, index_definitions: Vec::new() }
360    }
361}
362
363impl<'a> From<Vec<Row<'a>>> for Table<'a> {
364    fn from(values: Vec<Row<'a>>) -> Self {
365        Table::from(Values::from(values.into_iter()))
366    }
367}
368
369impl<'a> From<Values<'a>> for Table<'a> {
370    fn from(values: Values<'a>) -> Self {
371        Self { typ: TableType::Values(values), alias: None, database: None, index_definitions: Vec::new() }
372    }
373}
374
375impl<'a> From<(String, String)> for Table<'a> {
376    fn from(s: (String, String)) -> Table<'a> {
377        let table: Table<'a> = s.1.into();
378        table.database(s.0)
379    }
380}
381
382impl<'a> From<Select<'a>> for Table<'a> {
383    fn from(select: Select<'a>) -> Self {
384        Table { typ: TableType::Query(Box::new(select)), alias: None, database: None, index_definitions: Vec::new() }
385    }
386}
387
388impl<'a> Aliasable<'a> for Table<'a> {
389    type Target = Table<'a>;
390
391    fn alias<T>(mut self, alias: T) -> Self::Target
392    where
393        T: Into<Cow<'a, str>>,
394    {
395        self.alias = Some(alias.into());
396        self
397    }
398}
399
400aliasable!(String, (String, String));
401aliasable!(&'a str, (&'a str, &'a str));