sql_type/
lib.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12#![cfg_attr(not(test), no_std)]
13#![forbid(unsafe_code)]
14
//! Crate for typing SQL statements.
//!
//! ```
//! use sql_type::{schema::parse_schemas, type_statement, TypeOptions,
//!     SQLDialect, SQLArguments, StatementType, Issues};
//! let schemas = "
//!     CREATE TABLE `events` (
//!       `id` bigint(20) NOT NULL,
//!       `user` int(11) NOT NULL,
//!       `message` text NOT NULL
//!     );";
//!
//! let mut issues = Issues::new(schemas);
//!
//! // Compute a terse representation of the schemas
//! let schemas = parse_schemas(schemas,
//!     &mut issues,
//!     &TypeOptions::new().dialect(SQLDialect::MariaDB));
//! assert!(issues.is_ok());
//!
//! let sql = "SELECT `id`, `user`, `message` FROM `events` WHERE `id` = ?";
//! let mut issues = Issues::new(sql);
//! let stmt = type_statement(&schemas, sql, &mut issues,
//!     &TypeOptions::new().dialect(SQLDialect::MariaDB).arguments(SQLArguments::QuestionMark));
//! assert!(issues.is_ok());
//!
//! match stmt {
//!     StatementType::Select { columns, arguments } => {
//!         assert_eq!(columns.len(), 3);
//!         assert_eq!(arguments.len(), 1);
//!     }
//!     _ => panic!("Expected select statement"),
//! }
//! ```
49
50extern crate alloc;
51
52use alloc::vec::Vec;
53use schema::Schemas;
54use sql_parse::{parse_statement, ParseOptions};
55pub use sql_parse::{Fragment, Issue, Issues, Level};
56
57mod type_;
58mod type_binary_expression;
59mod type_delete;
60mod type_expression;
61mod type_function;
62mod type_insert_replace;
63mod type_reference;
64mod type_select;
65mod type_statement;
66mod type_update;
67mod typer;
68
69pub mod schema;
70pub use type_::{BaseType, FullType, Type};
71pub use type_insert_replace::AutoIncrementId;
72pub use type_select::SelectTypeColumn;
73use typer::Typer;
74
75pub use sql_parse::{SQLArguments, SQLDialect};
76
77/// Options used when typing sql or parsing a schema
78#[derive(Debug, Default, Clone)]
79pub struct TypeOptions {
80    parse_options: ParseOptions,
81    warn_unnamed_column_in_select: bool,
82    warn_duplicate_column_in_select: bool,
83}
84
85impl TypeOptions {
86    /// Produce new default options
87    pub fn new() -> Self {
88        Default::default()
89    }
90
91    /// Change what sql dialect is used
92    pub fn dialect(self, dialect: SQLDialect) -> Self {
93        Self {
94            parse_options: self.parse_options.dialect(dialect),
95            ..self
96        }
97    }
98
99    /// Change how sql arguments are supplied
100    pub fn arguments(self, arguments: SQLArguments) -> Self {
101        Self {
102            parse_options: self.parse_options.arguments(arguments),
103            ..self
104        }
105    }
106
107    /// Should we warn about unquoted identifiers
108    pub fn warn_unquoted_identifiers(self, warn_unquoted_identifiers: bool) -> Self {
109        Self {
110            parse_options: self
111                .parse_options
112                .warn_unquoted_identifiers(warn_unquoted_identifiers),
113            ..self
114        }
115    }
116
117    /// Should we warn about keywords not in ALL CAPS
118    pub fn warn_none_capital_keywords(self, warn_none_capital_keywords: bool) -> Self {
119        Self {
120            parse_options: self
121                .parse_options
122                .warn_none_capital_keywords(warn_none_capital_keywords),
123            ..self
124        }
125    }
126
127    /// Should we warn about unnamed columns in selects
128    pub fn warn_unnamed_column_in_select(self, warn_unnamed_column_in_select: bool) -> Self {
129        Self {
130            warn_unnamed_column_in_select,
131            ..self
132        }
133    }
134
135    /// Should we warn about duplicate columns in selects
136    pub fn warn_duplicate_column_in_select(self, warn_duplicate_column_in_select: bool) -> Self {
137        Self {
138            warn_duplicate_column_in_select,
139            ..self
140        }
141    }
142
143    /// Parse _LIST_ as special expression and type as a list of items
144    pub fn list_hack(self, list_hack: bool) -> Self {
145        Self {
146            parse_options: self.parse_options.list_hack(list_hack),
147            ..self
148        }
149    }
150}
151
/// Identifies one argument (placeholder) of a typed statement.
///
/// Unnamed placeholders are keyed by their position, named placeholders by
/// the name written in the SQL text.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ArgumentKey<'a> {
    /// Position of an unnamed argument
    Index(usize),
    /// Name of a named argument, borrowed from the statement source
    Identifier(&'a str),
}
160
161/// Type information of typed statement
162#[derive(Debug, Clone)]
163pub enum StatementType<'a> {
164    /// The statement was a select statement
165    Select {
166        /// The types and named of the columns return from the select
167        columns: Vec<SelectTypeColumn<'a>>,
168        /// The key and type of arguments to the query
169        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
170    },
171    /// The statement is a delete statement
172    Delete {
173        /// The key and type of arguments to the query
174        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
175        /// If present, the types and names of the columns returned from the delete
176        returning: Option<Vec<SelectTypeColumn<'a>>>,
177    },
178    /// The statement is an insert statement
179    Insert {
180        /// The insert happend in a table with a auto increment id row
181        yield_autoincrement: AutoIncrementId,
182        /// The key and type of arguments to the query
183        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
184        /// If present, the types and names of the columns returned from the insert
185        returning: Option<Vec<SelectTypeColumn<'a>>>,
186    },
187    /// The statement is a update statement
188    Update {
189        /// The key and type of arguments to the query
190        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
191    },
192    /// The statement is a replace statement
193    Replace {
194        /// The key and type of arguments to the query
195        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
196        /// If present, the types and names of the columns returned from the replace
197        returning: Option<Vec<SelectTypeColumn<'a>>>,
198    },
199    /// The query was not valid, errors are preset in issues
200    Invalid,
201}
202
203/// Type an sql statement with respect to a given schema
204pub fn type_statement<'a>(
205    schemas: &'a Schemas<'a>,
206    statement: &'a str,
207    issues: &mut Issues<'a>,
208    options: &TypeOptions,
209) -> StatementType<'a> {
210    if let Some(stmt) = parse_statement(statement, issues, &options.parse_options) {
211        let mut typer = Typer {
212            schemas,
213            issues,
214            reference_types: Vec::new(),
215            arg_types: Default::default(),
216            options,
217            with_schemas: Default::default(),
218        };
219        let t = type_statement::type_statement(&mut typer, &stmt);
220        let arguments = typer.arg_types;
221        match t {
222            type_statement::InnerStatementType::Select(s) => StatementType::Select {
223                columns: s.columns,
224                arguments,
225            },
226            type_statement::InnerStatementType::Delete { returning } => StatementType::Delete {
227                arguments,
228                returning: returning.map(|r| r.columns),
229            },
230            type_statement::InnerStatementType::Insert {
231                auto_increment_id,
232                returning,
233            } => StatementType::Insert {
234                yield_autoincrement: auto_increment_id,
235                arguments,
236                returning: returning.map(|r| r.columns),
237            },
238            type_statement::InnerStatementType::Update => StatementType::Update { arguments },
239            type_statement::InnerStatementType::Replace { returning } => StatementType::Replace {
240                arguments,
241                returning: returning.map(|r| r.columns),
242            },
243            type_statement::InnerStatementType::Invalid => StatementType::Invalid,
244        }
245    } else {
246        StatementType::Invalid
247    }
248}
249
250#[cfg(test)]
251mod tests {
252    use alloc::vec::Vec;
253    use codespan_reporting::{
254        diagnostic::{Diagnostic, Label},
255        files::SimpleFiles,
256        term::{
257            self,
258            termcolor::{ColorChoice, StandardStream},
259        },
260    };
261    use sql_parse::{Identifier, Issue, Issues, Level, SQLArguments, SQLDialect};
262
263    use crate::{
264        schema::parse_schemas, type_statement, ArgumentKey, AutoIncrementId, BaseType, FullType,
265        SelectTypeColumn, StatementType, Type, TypeOptions,
266    };
267
268    struct N<'a>(Option<&'a str>);
269    impl<'a> alloc::fmt::Display for N<'a> {
270        fn fmt(&self, f: &mut alloc::fmt::Formatter<'_>) -> alloc::fmt::Result {
271            if let Some(v) = self.0 {
272                v.fmt(f)
273            } else {
274                f.write_str("None")
275            }
276        }
277    }
278
279    struct N2<'a>(Option<Identifier<'a>>);
280    impl<'a> alloc::fmt::Display for N2<'a> {
281        fn fmt(&self, f: &mut alloc::fmt::Formatter<'_>) -> alloc::fmt::Result {
282            if let Some(v) = &self.0 {
283                v.fmt(f)
284            } else {
285                f.write_str("None")
286            }
287        }
288    }
289
290    fn check_no_errors(name: &str, src: &str, issues: &[Issue], errors: &mut usize) {
291        let mut files = SimpleFiles::new();
292        let file_id = files.add(name, &src);
293        let writer = StandardStream::stderr(ColorChoice::Always);
294        let config = codespan_reporting::term::Config::default();
295        for issue in issues {
296            let mut labels = vec![Label::primary(file_id, issue.span.clone())];
297            for fragment in &issue.fragments {
298                labels.push(
299                    Label::secondary(file_id, fragment.span.clone())
300                        .with_message(fragment.message.to_string()),
301                );
302            }
303            let d = match issue.level {
304                Level::Error => {
305                    *errors += 1;
306                    Diagnostic::error()
307                }
308                Level::Warning => Diagnostic::warning(),
309            };
310            let d = d
311                .with_message(issue.message.to_string())
312                .with_labels(labels);
313            term::emit(&mut writer.lock(), &config, &files, &d).unwrap();
314        }
315    }
316
317    fn str_to_type(t: &str) -> FullType<'static> {
318        let (t, not_null) = if let Some(t) = t.strip_suffix('!') {
319            (t, true)
320        } else {
321            (t, false)
322        };
323        let (t, list_hack) = if let Some(v) = t.strip_suffix("[]") {
324            (v, true)
325        } else {
326            (t, false)
327        };
328        let t = match t {
329            "b" => BaseType::Bool.into(),
330            "u8" => Type::U8,
331            "u16" => Type::U16,
332            "u32" => Type::U32,
333            "u64" => Type::U64,
334            "i8" => Type::I8,
335            "i16" => Type::I16,
336            "i32" => Type::I32,
337            "i64" => Type::I64,
338            "f32" => Type::F32,
339            "f64" => Type::F64,
340            "i" => BaseType::Integer.into(),
341            "f" => BaseType::Float.into(),
342            "str" => BaseType::String.into(),
343            "bytes" => BaseType::Bytes.into(),
344            "dt" => BaseType::DateTime.into(),
345            "json" => Type::JSON,
346            "any" => BaseType::Any.into(),
347            _ => panic!("Unknown type {}", t),
348        };
349        let mut t = FullType::new(t, not_null);
350        if list_hack {
351            t.list_hack = true;
352        }
353        t
354    }
355
356    fn check_arguments(
357        name: &str,
358        got: &[(ArgumentKey<'_>, FullType<'_>)],
359        expected: &str,
360        errors: &mut usize,
361    ) {
362        if expected.is_empty() {
363            for (cnt, value) in got.iter().enumerate() {
364                println!("{}: Unexpected argument {} type {:?}", name, cnt, value);
365                *errors += 1;
366            }
367            return;
368        }
369        let mut got2 = Vec::new();
370        let inv = FullType::invalid();
371        for (k, v) in got {
372            match k {
373                ArgumentKey::Index(i) => {
374                    while got2.len() <= *i {
375                        got2.push(&inv);
376                    }
377                    got2[*i] = v;
378                }
379                ArgumentKey::Identifier(k) => {
380                    println!("{}: Got named argument {}", name, k);
381                    *errors += 1;
382                }
383            }
384        }
385        let mut cnt = 0;
386        for (i, t) in expected.split(',').enumerate() {
387            let t = t.trim();
388            let t = str_to_type(t);
389            if let Some(v) = got2.get(i) {
390                if *v != &t {
391                    println!("{}: Expected type {} for argument {} got {}", name, t, i, v);
392                    *errors += 1;
393                }
394            } else {
395                println!("{}: Expected type {} for argument {} got None", name, t, i);
396                *errors += 1;
397            }
398            cnt += 1;
399        }
400        while cnt < got.len() {
401            println!("{}: Unexpected argument {} type {:?}", name, cnt, got[cnt]);
402            cnt += 1;
403            *errors += 1;
404        }
405    }
406
407    fn check_columns(name: &str, got: &[SelectTypeColumn<'_>], expected: &str, errors: &mut usize) {
408        let mut cnt = 0;
409        for (i, t) in expected.split(',').enumerate() {
410            let t = t.trim();
411            let (cname, t) = t.split_once(":").unwrap();
412            let t = str_to_type(t);
413            let cname = if cname.is_empty() { None } else { Some(cname) };
414            if let Some(v) = got.get(i) {
415                if v.name.as_deref() != cname || v.type_ != t {
416                    println!(
417                        "{}: Expected column {} with name {} of type {} got {} of type {}",
418                        name,
419                        i,
420                        N(cname),
421                        t,
422                        N2(v.name.clone()),
423                        v.type_
424                    );
425                    *errors += 1;
426                }
427            } else {
428                println!(
429                    "{}: Expected column {} with name {} of type {} got None",
430                    name,
431                    i,
432                    N(cname),
433                    t
434                );
435                *errors += 1;
436            }
437            cnt += 1;
438        }
439        while cnt < got.len() {
440            println!(
441                "{}: Unexpected column {} with name {} of type {}",
442                name,
443                cnt,
444                N2(got[cnt].name.clone()),
445                got[cnt].type_
446            );
447            cnt += 1;
448            *errors += 1;
449        }
450    }
451
452    #[test]
453    fn mariadb() {
454        let schema_src = "
455
456        DROP TABLE IF EXISTS `t1`;
457        CREATE TABLE `t1` (
458          `id` int(11) NOT NULL,
459          `cbool` tinyint(1) NOT NULL,
460          `cu8` tinyint UNSIGNED NOT NULL,
461          `cu16` smallint UNSIGNED NOT NULL,
462          `cu32` int UNSIGNED NOT NULL,
463          `cu64` bigint UNSIGNED NOT NULL,
464          `ci8` tinyint,
465          `ci16` smallint,
466          `ci32` int,
467          `ci64` bigint,
468          `cbin` binary(16),
469          `ctext` varchar(100) NOT NULL,
470          `cbytes` blob,
471          `cf32` float,
472          `cf64` double,
473          `cu8_plus_one` tinyint UNSIGNED GENERATED ALWAYS AS (
474            `cu8` + 1
475           ) STORED,
476          `status` varchar(10) GENERATED ALWAYS AS (case when `cu8` <> 0 and `cu16` = 0 then 'a' when
477            `cbool` then 'b' when `ci32` = 42 then 'd' when `cu64` = 43 then 'x' when
478            `ci64` = 12 then 'y' else 'z' end) VIRTUAL
479        ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
480
481        ALTER TABLE `t1`
482          MODIFY `id` int(11) NOT NULL AUTO_INCREMENT;
483
484        DROP INDEX IF EXISTS `hat` ON `t1`;
485
486        CREATE INDEX `hat2` ON `t1` (`id`, `cf64`);
487
488        CREATE TABLE `t2` (
489          `id` int(11) NOT NULL AUTO_INCREMENT,
490          `t1_id` int(11) NOT NULL);
491
492        CREATE TABLE `t3` (
493            `id` int(11) NOT NULL AUTO_INCREMENT,
494            `text` TEXT);
495
496        CREATE TABLE `t4` (
497            `id` int(11) NOT NULL AUTO_INCREMENT,
498            `dt` datetime NOT NULL);
499        ";
500
501        let options = TypeOptions::new().dialect(SQLDialect::MariaDB);
502        let mut issues = Issues::new(schema_src);
503        let schema = parse_schemas(schema_src, &mut issues, &options);
504        let mut errors = 0;
505        check_no_errors("schema", schema_src, issues.get(), &mut errors);
506
507        let options = TypeOptions::new()
508            .dialect(SQLDialect::MariaDB)
509            .arguments(SQLArguments::QuestionMark);
510
511        {
512            let name = "q1";
513            let src =
514                "SELECT `id`, `cbool`, `cu8`, `cu8_plus_one`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
515                `ctext`, `cbytes`, `cf32`, `cf64` FROM `t1` WHERE ci8 IS NOT NULL
516                AND `cbool`=? AND `cu8`=? AND `cu16`=? AND `cu32`=? AND `cu64`=?
517                AND `ci8`=? AND `ci16`=? AND `ci32`=? AND `ci64`=?
518                AND `ctext`=? AND `cbytes`=? AND `cf32`=? AND `cf64`=?";
519
520            let mut issues: Issues<'_> = Issues::new(src);
521            let q = type_statement(&schema, src, &mut issues, &options);
522            check_no_errors(name, src, issues.get(), &mut errors);
523            if let StatementType::Select { arguments, columns } = q {
524                check_arguments(
525                    name,
526                    &arguments,
527                    "b,i,i,i,i,i,i,i,i,str,bytes,f,f",
528                    &mut errors,
529                );
530                check_columns(
531                    name,
532                    &columns,
533                    "id:i32!,cbool:b!,cu8:u8!,cu8_plus_one:u8!,cu16:u16!,cu32:u32!,cu64:u64!,
534                    ci8:i8!,ci16:i16!,ci32:i32!,ci64:i64!,ctext:str!,cbytes:bytes!,cf32:f32!,cf64:f64!",
535                    &mut errors,
536                );
537            } else {
538                println!("{} should be select", name);
539                errors += 1;
540            }
541        }
542
543        {
544            let name = "q1.1";
545            let src =
546                "SELECT `id`, `cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
547                `ctext`, `cbytes`, `cf32`, `cf64`, `cbin` FROM `t1` WHERE ci8 IS NOT NULL";
548
549            let mut issues: Issues<'_> = Issues::new(src);
550            let q = type_statement(&schema, src, &mut issues, &options);
551            check_no_errors(name, src, issues.get(), &mut errors);
552            if let StatementType::Select { arguments, columns } = q {
553                check_arguments(name, &arguments, "", &mut errors);
554                check_columns(
555                    name,
556                    &columns,
557                    "id:i32!,cbool:b!,cu8:u8!,cu16:u16!,cu32:u32!,cu64:u64!,
558                    ci8:i8!,ci16:i16,ci32:i32,ci64:i64,ctext:str!,cbytes:bytes,cf32:f32,cf64:f64,cbin:bytes",
559                    &mut errors,
560                );
561            } else {
562                println!("{} should be select", name);
563                errors += 1;
564            }
565        }
566
567        {
568            let name = "q2";
569            let src =
570            "INSERT INTO `t1` (`cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
571            `ctext`, `cbytes`, `cf32`, `cf64`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
572
573            let mut issues: Issues<'_> = Issues::new(src);
574            let q = type_statement(&schema, src, &mut issues, &options);
575            check_no_errors(name, src, issues.get(), &mut errors);
576            if let StatementType::Insert {
577                arguments,
578                yield_autoincrement,
579                returning,
580            } = q
581            {
582                check_arguments(
583                    name,
584                    &arguments,
585                    "b!,u8!,u16!,u32!,u64!,i8,i16,i32,i64,str!,bytes,f32,f64",
586                    &mut errors,
587                );
588                if yield_autoincrement != AutoIncrementId::Yes {
589                    println!("{} should yield autoincrement", name);
590                    errors += 1;
591                }
592                if returning.is_some() {
593                    println!("{} should not return columns", name);
594                    errors += 1;
595                }
596            } else {
597                println!("{} should be insert", name);
598                errors += 1;
599            }
600        }
601
602        {
603            let name = "q3";
604            let src =
605                "DELETE `t1` FROM `t1`, `t2` WHERE `t1`.`id` = `t2`.`t1_id` AND `t2`.`id` = ?";
606            let mut issues: Issues<'_> = Issues::new(src);
607            let q = type_statement(&schema, src, &mut issues, &options);
608            check_no_errors(name, src, issues.get(), &mut errors);
609            if let StatementType::Delete { arguments, .. } = q {
610                check_arguments(name, &arguments, "i", &mut errors);
611            } else {
612                println!("{} should be delete", name);
613                errors += 1;
614            }
615        }
616
617        {
618            let name = "q4";
619            let src = "INSERT INTO `t2` (`t1_id`) VALUES (?) ON DUPLICATE KEY UPDATE `t1_id`=?";
620            let mut issues: Issues<'_> = Issues::new(src);
621            let q = type_statement(&schema, src, &mut issues, &options);
622            check_no_errors(name, src, issues.get(), &mut errors);
623            if let StatementType::Insert {
624                arguments,
625                yield_autoincrement,
626                returning,
627            } = q
628            {
629                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
630                if yield_autoincrement != AutoIncrementId::Optional {
631                    println!("{} should yield optional auto increment", name);
632                    errors += 1;
633                }
634                if returning.is_some() {
635                    println!("{} should not return columns", name);
636                    errors += 1;
637                }
638            } else {
639                println!("{} should be insert", name);
640                errors += 1;
641            }
642        }
643
644        {
645            let name = "q5";
646            let src = "INSERT IGNORE INTO `t2` SET `t1_id`=?";
647            let mut issues: Issues<'_> = Issues::new(src);
648            let q = type_statement(&schema, src, &mut issues, &options);
649            check_no_errors(name, src, issues.get(), &mut errors);
650            if let StatementType::Insert {
651                arguments,
652                yield_autoincrement,
653                returning,
654            } = q
655            {
656                check_arguments(name, &arguments, "i32!", &mut errors);
657                if yield_autoincrement != AutoIncrementId::Optional {
658                    println!("{} should yield optional auto increment", name);
659                    errors += 1;
660                }
661                if returning.is_some() {
662                    println!("{} should not return columns", name);
663                    errors += 1;
664                }
665            } else {
666                println!("{} should be insert", name);
667                errors += 1;
668            }
669        }
670
671        {
672            let name = "q6";
673            let src = "SELECT IF(`ci32` IS NULL, `cbool`, ?) AS `cc` FROM `t1`";
674            let mut issues: Issues<'_> = Issues::new(src);
675            let q = type_statement(&schema, src, &mut issues, &options);
676            check_no_errors(name, src, issues.get(), &mut errors);
677            if let StatementType::Select { arguments, columns } = q {
678                check_arguments(name, &arguments, "b", &mut errors);
679                check_columns(name, &columns, "cc:b", &mut errors);
680            } else {
681                println!("{} should be select", name);
682                errors += 1;
683            }
684        }
685
686        {
687            let name = "q7";
688            let src = "SELECT FROM_UNIXTIME(CAST(UNIX_TIMESTAMP() AS DOUBLE)) AS `cc` FROM `t1` WHERE `id`=?";
689            let mut issues: Issues<'_> = Issues::new(src);
690            let q = type_statement(&schema, src, &mut issues, &options);
691            check_no_errors(name, src, issues.get(), &mut errors);
692            if let StatementType::Select { arguments, columns } = q {
693                check_arguments(name, &arguments, "i", &mut errors);
694                check_columns(name, &columns, "cc:dt!", &mut errors);
695            } else {
696                println!("{} should be select", name);
697                errors += 1;
698            }
699        }
700
701        {
702            let name = "q8";
703            let src = "REPLACE INTO `t2` SET `id` = ?, `t1_id`=?";
704            let mut issues: Issues<'_> = Issues::new(src);
705            let q = type_statement(&schema, src, &mut issues, &options);
706            check_no_errors(name, src, issues.get(), &mut errors);
707            if let StatementType::Replace {
708                arguments,
709                returning,
710            } = q
711            {
712                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
713                if returning.is_some() {
714                    println!("{} should not return columns", name);
715                    errors += 1;
716                }
717            } else {
718                println!("{} should be replace", name);
719                errors += 1;
720            }
721        }
722
723        {
724            let name = "q9";
725            let src = "INSERT INTO `t2` (`t1_id`) VALUES (32) ON DUPLICATE KEY UPDATE `t1_id` = `t1_id` + VALUES(`t1_id`)";
726            let mut issues: Issues<'_> = Issues::new(src);
727            let q = type_statement(&schema, src, &mut issues, &options);
728            check_no_errors(name, src, issues.get(), &mut errors);
729            if let StatementType::Insert { arguments, .. } = q {
730                check_arguments(name, &arguments, "", &mut errors);
731            } else {
732                println!("{} should be insert", name);
733                errors += 1;
734            }
735        }
736
737        {
738            let name = "q10";
739            let src =
740                "SELECT SUBSTRING_INDEX(`text`, '/', 5) AS `k` FROM `t3` WHERE `text` LIKE '%T%'";
741            let mut issues: Issues<'_> = Issues::new(src);
742            let q = type_statement(&schema, src, &mut issues, &options);
743            check_no_errors(name, src, issues.get(), &mut errors);
744            if let StatementType::Select { arguments, columns } = q {
745                check_arguments(name, &arguments, "", &mut errors);
746                check_columns(name, &columns, "k:str!", &mut errors);
747            } else {
748                println!("{} should be select", name);
749                errors += 1;
750            }
751        }
752
753        {
754            let name = "q11";
755            let src = "SELECT * FROM `t1`, `t2` LEFT JOIN `t3` ON `t3`.`id` = `t1`.`id`";
756            let mut issues: Issues<'_> = Issues::new(src);
757            type_statement(&schema, src, &mut issues, &options);
758            if !issues.get().iter().any(|i| i.level == Level::Error) {
759                println!("{} should be an error", name);
760                errors += 1;
761            }
762        }
763
764        {
765            let name = "q12";
766            let src =
767                "SELECT JSON_REPLACE('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]', 4, '$.C[3]', 3) AS `k` FROM `t3`";
768            let mut issues: Issues<'_> = Issues::new(src);
769            let q = type_statement(&schema, src, &mut issues, &options);
770            check_no_errors(name, src, issues.get(), &mut errors);
771            if let StatementType::Select { arguments, columns } = q {
772                check_arguments(name, &arguments, "", &mut errors);
773                check_columns(name, &columns, "k:json", &mut errors);
774            } else {
775                println!("{} should be select", name);
776                errors += 1;
777            }
778        }
779
780        {
781            let options = options.clone().list_hack(true);
782            let name = "q13";
783            let src = "SELECT `id` FROM `t1` WHERE `id` IN (_LIST_)";
784            let mut issues: Issues<'_> = Issues::new(src);
785            let q = type_statement(&schema, src, &mut issues, &options);
786            check_no_errors(name, src, issues.get(), &mut errors);
787            if let StatementType::Select { arguments, columns } = q {
788                check_arguments(name, &arguments, "i[]", &mut errors);
789                check_columns(name, &columns, "id:i32!", &mut errors);
790            } else {
791                println!("{} should be select", name);
792                errors += 1;
793            }
794        }
795
796        {
797            let name = "q14";
798            let src = "SELECT CAST(NULL AS CHAR) AS `id`";
799            let mut issues: Issues<'_> = Issues::new(src);
800            let q = type_statement(&schema, src, &mut issues, &options);
801            check_no_errors(name, src, issues.get(), &mut errors);
802            if let StatementType::Select { arguments, columns } = q {
803                check_arguments(name, &arguments, "", &mut errors);
804                check_columns(name, &columns, "id:str", &mut errors);
805            } else {
806                println!("{} should be select", name);
807                errors += 1;
808            }
809        }
810
811        {
812            let name = "q15";
813            let src =
814				"INSERT INTO `t1` (`cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
815            `ctext`, `cbytes`, `cf32`, `cf64`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
816                 RETURNING `id`, `cbool`, `cu8`, `ctext`, `cf64`";
817            let mut issues: Issues<'_> = Issues::new(src);
818            let q = type_statement(&schema, src, &mut issues, &options);
819            check_no_errors(name, src, issues.get(), &mut errors);
820            if let StatementType::Insert {
821                arguments,
822                yield_autoincrement,
823                returning,
824            } = q
825            {
826                check_arguments(
827                    name,
828                    &arguments,
829                    "b!,u8!,u16!,u32!,u64!,i8,i16,i32,i64,str!,bytes,f32,f64",
830                    &mut errors,
831                );
832                if yield_autoincrement != AutoIncrementId::Yes {
833                    println!("{} should yield autoincrement", name);
834                    errors += 1;
835                }
836                if let Some(returning) = returning {
837                    check_columns(
838                        name,
839                        &returning,
840                        "id:i32!,cbool:b!,cu8:u8!,ctext:str!,cf64:f64",
841                        &mut errors,
842                    );
843                } else {
844                    println!("{} should return columns", name);
845                    errors += 1;
846                }
847            } else {
848                println!("{} should be insert", name);
849                errors += 1;
850            }
851        }
852
853        {
854            let name = "q16";
855            let src = "REPLACE INTO `t2` SET `id` = ?, `t1_id`=? RETURNING `id`";
856            let mut issues: Issues<'_> = Issues::new(src);
857            let q = type_statement(&schema, src, &mut issues, &options);
858            check_no_errors(name, src, issues.get(), &mut errors);
859            if let StatementType::Replace {
860                arguments,
861                returning,
862            } = q
863            {
864                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
865                if let Some(returning) = returning {
866                    check_columns(name, &returning, "id:i32!", &mut errors);
867                } else {
868                    println!("{} should return columns", name);
869                    errors += 1;
870                }
871            } else {
872                println!("{} should be replace", name);
873                errors += 1;
874            }
875        }
876
877        {
878            let name = "q17";
879            let src = "SELECT dt, UNIX_TIMESTAMP(dt) AS t FROM t4";
880            let mut issues: Issues<'_> = Issues::new(src);
881            let q = type_statement(&schema, src, &mut issues, &options);
882            check_no_errors(name, src, issues.get(), &mut errors);
883            if let StatementType::Select { arguments, columns } = q {
884                check_arguments(name, &arguments, "", &mut errors);
885                check_columns(name, &columns, "dt:dt!,t:i64!", &mut errors);
886            } else {
887                println!("{} should be select", name);
888                errors += 1;
889            }
890        }
891
892        {
893            let name = "q17";
894            let src = "SELECT CONCAT(?, \"hat\") AS c";
895            let mut issues: Issues<'_> = Issues::new(src);
896            let q = type_statement(&schema, src, &mut issues, &options);
897            check_no_errors(name, src, issues.get(), &mut errors);
898            if let StatementType::Select { arguments, columns } = q {
899                check_arguments(name, &arguments, "any", &mut errors);
900                check_columns(name, &columns, "c:str", &mut errors);
901            } else {
902                println!("{} should be selsect", name);
903                errors += 1;
904            }
905        }
906
907        {
908            let name = "q18";
909            let src = "SELECT CAST(\"::0\" AS INET6) AS `id`";
910            let mut issues: Issues<'_> = Issues::new(src);
911            let q = type_statement(&schema, src, &mut issues, &options);
912            check_no_errors(name, src, issues.get(), &mut errors);
913            if let StatementType::Select { arguments, columns } = q {
914                check_arguments(name, &arguments, "", &mut errors);
915                check_columns(name, &columns, "id:str!", &mut errors);
916            } else {
917                println!("{} should be select", name);
918                errors += 1;
919            }
920        }
921
922        {
923            let name: &str = "q18";
924            let src = "SELECT SUBSTRING(`cbytes`, 1, 5) AS `k` FROM `t1`";
925            let mut issues: Issues<'_> = Issues::new(src);
926            let q = type_statement(&schema, src, &mut issues, &options);
927            check_no_errors(name, src, issues.get(), &mut errors);
928            if let StatementType::Select { arguments, columns } = q {
929                check_arguments(name, &arguments, "", &mut errors);
930                check_columns(name, &columns, "k:bytes", &mut errors);
931            } else {
932                println!("{} should be select", name);
933                errors += 1;
934            }
935        }
936
937        {
938            let name = "q19";
939            let src = "SELECT SUBSTRING(`ctext`, 1, 5) AS `k` FROM `t1`";
940            let mut issues: Issues<'_> = Issues::new(src);
941            let q = type_statement(&schema, src, &mut issues, &options);
942            check_no_errors(name, src, issues.get(), &mut errors);
943            if let StatementType::Select { arguments, columns } = q {
944                check_arguments(name, &arguments, "", &mut errors);
945                check_columns(name, &columns, "k:str!", &mut errors);
946            } else {
947                println!("{} should be select", name);
948                errors += 1;
949            }
950        }
951
952        {
953            let name = "q19";
954            let src = "SELECT SUBSTRING(`ctext`, 1, 5) AS `k` FROM `t1`";
955            let mut issues: Issues<'_> = Issues::new(src);
956            let q = type_statement(&schema, src, &mut issues, &options);
957            check_no_errors(name, src, issues.get(), &mut errors);
958            if let StatementType::Select { arguments, columns } = q {
959                check_arguments(name, &arguments, "", &mut errors);
960                check_columns(name, &columns, "k:str!", &mut errors);
961            } else {
962                println!("{} should be select", name);
963                errors += 1;
964            }
965        }
966
967        {
968            let name = "q20";
969            let src = "SELECT JSON_QUERY('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]') AS `k` FROM `t3`";
970            let mut issues: Issues<'_> = Issues::new(src);
971            let q = type_statement(&schema, src, &mut issues, &options);
972            check_no_errors(name, src, issues.get(), &mut errors);
973            if let StatementType::Select { arguments, columns } = q {
974                check_arguments(name, &arguments, "", &mut errors);
975                check_columns(name, &columns, "k:json", &mut errors);
976            } else {
977                println!("{} should be select", name);
978                errors += 1;
979            }
980        }
981
982        {
983            let name = "q21";
984            let src =
985                "SELECT JSON_REMOVE('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]', '$.C[3]') AS `k` FROM `t3`";
986            let mut issues: Issues<'_> = Issues::new(src);
987            let q = type_statement(&schema, src, &mut issues, &options);
988            check_no_errors(name, src, issues.get(), &mut errors);
989            if let StatementType::Select { arguments, columns } = q {
990                check_arguments(name, &arguments, "", &mut errors);
991                check_columns(name, &columns, "k:json", &mut errors);
992            } else {
993                println!("{} should be select", name);
994                errors += 1;
995            }
996        }
997
998        {
999            let name = "q22";
1000            let src = "SELECT JSON_OVERLAPS('false', 'false') AS `k` FROM `t3`";
1001            let mut issues: Issues<'_> = Issues::new(src);
1002            let q = type_statement(&schema, src, &mut issues, &options);
1003            check_no_errors(name, src, issues.get(), &mut errors);
1004            if let StatementType::Select { arguments, columns } = q {
1005                check_arguments(name, &arguments, "", &mut errors);
1006                check_columns(name, &columns, "k:b!", &mut errors);
1007            } else {
1008                println!("{} should be select", name);
1009                errors += 1;
1010            }
1011        }
1012
1013        {
1014            let name = "q23";
1015            let src = "SELECT JSON_OVERLAPS('false', NULL) AS `k` FROM `t3`";
1016            let mut issues: Issues<'_> = Issues::new(src);
1017            let q = type_statement(&schema, src, &mut issues, &options);
1018            check_no_errors(name, src, issues.get(), &mut errors);
1019            if let StatementType::Select { arguments, columns } = q {
1020                check_arguments(name, &arguments, "", &mut errors);
1021                check_columns(name, &columns, "k:b", &mut errors);
1022            } else {
1023                println!("{} should be select", name);
1024                errors += 1;
1025            }
1026        }
1027
1028        {
1029            let name = "q24";
1030            let src =
1031                "SELECT JSON_CONTAINS('{\"A\": 0, \"B\": [\"x\", \"y\"]}', '\"x\"', '$.B') AS `k` FROM `t3`";
1032            let mut issues: Issues<'_> = Issues::new(src);
1033            let q = type_statement(&schema, src, &mut issues, &options);
1034            check_no_errors(name, src, issues.get(), &mut errors);
1035            if let StatementType::Select { arguments, columns } = q {
1036                check_arguments(name, &arguments, "", &mut errors);
1037                check_columns(name, &columns, "k:b!", &mut errors);
1038            } else {
1039                println!("{} should be select", name);
1040                errors += 1;
1041            }
1042        }
1043
1044        {
1045            let name = "q25";
1046            let src =
1047                "SELECT JSON_CONTAINS('{\"A\": 0, \"B\": [\"x\", \"y\"]}', NULL, '$.A') AS `k` FROM `t3`";
1048            let mut issues: Issues<'_> = Issues::new(src);
1049            let q = type_statement(&schema, src, &mut issues, &options);
1050            check_no_errors(name, src, issues.get(), &mut errors);
1051            if let StatementType::Select { arguments, columns } = q {
1052                check_arguments(name, &arguments, "", &mut errors);
1053                check_columns(name, &columns, "k:b", &mut errors);
1054            } else {
1055                println!("{} should be select", name);
1056                errors += 1;
1057            }
1058        }
1059
1060        {
1061            let name = "q26";
1062            let src = "SELECT `id` FROM `t1` FORCE INDEX (`hat`)";
1063            let mut issues: Issues<'_> = Issues::new(src);
1064            type_statement(&schema, src, &mut issues, &options);
1065            if issues.is_ok() {
1066                println!("{} should fail", name);
1067                errors += 1;
1068            }
1069        }
1070
1071        {
1072            let name = "q27";
1073            let src = "SELECT `id` FROM `t1` USE INDEX (`hat2`)";
1074            let mut issues: Issues<'_> = Issues::new(src);
1075            let q = type_statement(&schema, src, &mut issues, &options);
1076            check_no_errors(name, src, issues.get(), &mut errors);
1077            if let StatementType::Select { arguments, columns } = q {
1078                check_arguments(name, &arguments, "", &mut errors);
1079                check_columns(name, &columns, "id:i32!", &mut errors);
1080            } else {
1081                println!("{} should be select", name);
1082                errors += 1;
1083            }
1084        }
1085
1086        if errors != 0 {
1087            panic!("{} errors in test", errors);
1088        }
1089    }
1090
1091    #[test]
1092    fn postgresql() {
1093        let schema_src = "
1094        BEGIN;
1095
1096        DO $$ BEGIN
1097            CREATE TYPE my_enum AS ENUM (
1098            'V1',
1099            'V2',
1100            'V3'
1101        );
1102        EXCEPTION
1103            WHEN duplicate_object THEN null;
1104        END $$;
1105
1106        CREATE TABLE IF NOT EXISTS t1 (
1107            id bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
1108            path text NOT NULL UNIQUE,
1109            v my_enum NOT NULL,
1110            time timestamptz NOT NULL DEFAULT now(),
1111            old_id bigint,
1112            CONSTRAINT t1__old
1113            FOREIGN KEY(old_id) 
1114            REFERENCES t1(id)
1115            ON DELETE SET NULL
1116        );
1117
1118        CREATE TABLE IF NOT EXISTS t2 (
1119            id bigint NOT NULL PRIMARY KEY
1120        );
1121
1122        DROP INDEX IF EXISTS t2_index;
1123
1124        CREATE INDEX t2_index2 ON t2 (id);
1125
1126        COMMIT;
1127        ";
1128
1129        let options = TypeOptions::new().dialect(SQLDialect::PostgreSQL);
1130        let mut issues = Issues::new(schema_src);
1131        let schema = parse_schemas(schema_src, &mut issues, &options);
1132        let mut errors = 0;
1133        check_no_errors("schema", schema_src, issues.get(), &mut errors);
1134
1135        let options = TypeOptions::new()
1136            .dialect(SQLDialect::PostgreSQL)
1137            .arguments(SQLArguments::Dollar);
1138
1139        {
1140            let name = "q1";
1141            let src =
1142                "INSERT INTO t2 (id) SELECT id FROM t1 WHERE path=$1 ON CONFLICT (id) DO NOTHING RETURNING id";
1143            let mut issues = Issues::new(src);
1144            let q = type_statement(&schema, src, &mut issues, &options);
1145            check_no_errors(name, src, issues.get(), &mut errors);
1146            if let StatementType::Insert {
1147                arguments,
1148                returning,
1149                ..
1150            } = q
1151            {
1152                check_arguments(name, &arguments, "str", &mut errors);
1153                check_columns(name, &returning.expect("Returning"), "id:i64!", &mut errors);
1154            } else {
1155                println!("{} should be select", name);
1156                errors += 1;
1157            }
1158        }
1159
1160        {
1161            let name = "q2";
1162            let src =
1163                "WITH hat AS (DELETE FROM t1 WHERE old_id=42 RETURNING id) INSERT INTO t2 (id) SELECT id FROM hat";
1164            let mut issues = Issues::new(src);
1165            let q = type_statement(&schema, src, &mut issues, &options);
1166            check_no_errors(name, src, issues.get(), &mut errors);
1167
1168            if let StatementType::Insert { arguments, .. } = q {
1169                check_arguments(name, &arguments, "", &mut errors);
1170            } else {
1171                println!("{} should be select {q:?}", name);
1172                errors += 1;
1173            }
1174        }
1175
1176        if errors != 0 {
1177            panic!("{} errors in test", errors);
1178        }
1179    }
1180}