sql_type/
lib.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12#![cfg_attr(not(test), no_std)]
13#![forbid(unsafe_code)]
14
15//! Crate for typing SQL statements.
16//!
17//! ```
18//! use sql_type::{schema::parse_schemas, type_statement, TypeOptions,
19//!     SQLDialect, SQLArguments, StatementType, Issues};
20//! let schemas = "
21//!     CREATE TABLE `events` (
22//!       `id` bigint(20) NOT NULL,
23//!       `user` int(11) NOT NULL,
24//!       `message` text NOT NULL
25//!     );";
26//!
27//! let mut issues = Issues::new(schemas);
28//!
29//! // Compute terse representation of the schemas
30//! let schemas = parse_schemas(schemas,
31//!     &mut issues,
32//!     &TypeOptions::new().dialect(SQLDialect::MariaDB));
33//! assert!(issues.is_ok());
34//!
35//! let sql = "SELECT `id`, `user`, `message` FROM `events` WHERE `id` = ?";
36//! let mut issues = Issues::new(sql);
37//! let stmt = type_statement(&schemas, sql, &mut issues,
38//!     &TypeOptions::new().dialect(SQLDialect::MariaDB).arguments(SQLArguments::QuestionMark));
39//! assert!(issues.is_ok());
40//!
//! match stmt {
//!     StatementType::Select { columns, arguments } => {
//!         assert_eq!(columns.len(), 3);
//!         assert_eq!(arguments.len(), 1);
//!     }
//!     _ => panic!("Expected select statement"),
//! }
48//! ```
49
50extern crate alloc;
51
52use alloc::vec::Vec;
53use schema::Schemas;
54use sql_parse::{parse_statement, ParseOptions};
55pub use sql_parse::{Fragment, Issue, Issues, Level};
56
57mod type_;
58mod type_binary_expression;
59mod type_delete;
60mod type_expression;
61mod type_function;
62mod type_insert_replace;
63mod type_reference;
64mod type_select;
65mod type_statement;
66mod type_update;
67mod typer;
68
69pub mod schema;
70pub use type_::{BaseType, FullType, Type};
71pub use type_insert_replace::AutoIncrementId;
72pub use type_select::SelectTypeColumn;
73use typer::Typer;
74
75pub use sql_parse::{SQLArguments, SQLDialect};
76
77/// Options used when typing sql or parsing a schema
78#[derive(Debug, Default, Clone)]
79pub struct TypeOptions {
80    parse_options: ParseOptions,
81    warn_unnamed_column_in_select: bool,
82    warn_duplicate_column_in_select: bool,
83}
84
85impl TypeOptions {
86    /// Produce new default options
87    pub fn new() -> Self {
88        Default::default()
89    }
90
91    /// Change what sql dialect is used
92    pub fn dialect(self, dialect: SQLDialect) -> Self {
93        Self {
94            parse_options: self.parse_options.dialect(dialect),
95            ..self
96        }
97    }
98
99    /// Change how sql arguments are supplied
100    pub fn arguments(self, arguments: SQLArguments) -> Self {
101        Self {
102            parse_options: self.parse_options.arguments(arguments),
103            ..self
104        }
105    }
106
107    /// Should we warn about unquoted identifiers
108    pub fn warn_unquoted_identifiers(self, warn_unquoted_identifiers: bool) -> Self {
109        Self {
110            parse_options: self
111                .parse_options
112                .warn_unquoted_identifiers(warn_unquoted_identifiers),
113            ..self
114        }
115    }
116
117    /// Should we warn about keywords not in ALL CAPS
118    pub fn warn_none_capital_keywords(self, warn_none_capital_keywords: bool) -> Self {
119        Self {
120            parse_options: self
121                .parse_options
122                .warn_none_capital_keywords(warn_none_capital_keywords),
123            ..self
124        }
125    }
126
127    /// Should we warn about unnamed columns in selects
128    pub fn warn_unnamed_column_in_select(self, warn_unnamed_column_in_select: bool) -> Self {
129        Self {
130            warn_unnamed_column_in_select,
131            ..self
132        }
133    }
134
135    /// Should we warn about duplicate columns in selects
136    pub fn warn_duplicate_column_in_select(self, warn_duplicate_column_in_select: bool) -> Self {
137        Self {
138            warn_duplicate_column_in_select,
139            ..self
140        }
141    }
142
143    /// Parse _LIST_ as special expression and type as a list of items
144    pub fn list_hack(self, list_hack: bool) -> Self {
145        Self {
146            parse_options: self.parse_options.list_hack(list_hack),
147            ..self
148        }
149    }
150}
151
/// Identifies one argument (placeholder) of a typed statement.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ArgumentKey<'a> {
    /// A positional (unnamed) argument, identified by its index
    Index(usize),
    /// A named argument, identified by its name
    Identifier(&'a str),
}
160
161/// Type information of typed statement
162#[derive(Debug, Clone)]
163pub enum StatementType<'a> {
164    /// The statement was a select statement
165    Select {
166        /// The types and named of the columns return from the select
167        columns: Vec<SelectTypeColumn<'a>>,
168        /// The key and type of arguments to the query
169        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
170    },
171    /// The statement is a delete statement
172    Delete {
173        /// The key and type of arguments to the query
174        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
175        /// If present, the types and names of the columns returned from the delete
176        returning: Option<Vec<SelectTypeColumn<'a>>>,
177    },
178    /// The statement is an insert statement
179    Insert {
180        /// The insert happend in a table with a auto increment id row
181        yield_autoincrement: AutoIncrementId,
182        /// The key and type of arguments to the query
183        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
184        /// If present, the types and names of the columns returned from the insert
185        returning: Option<Vec<SelectTypeColumn<'a>>>,
186    },
187    /// The statement is a update statement
188    Update {
189        /// The key and type of arguments to the query
190        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
191    },
192    /// The statement is a replace statement
193    Replace {
194        /// The key and type of arguments to the query
195        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
196        /// If present, the types and names of the columns returned from the replace
197        returning: Option<Vec<SelectTypeColumn<'a>>>,
198    },
199    /// The query was not valid, errors are preset in issues
200    Invalid,
201}
202
203/// Type an sql statement with respect to a given schema
204pub fn type_statement<'a>(
205    schemas: &'a Schemas<'a>,
206    statement: &'a str,
207    issues: &mut Issues<'a>,
208    options: &TypeOptions,
209) -> StatementType<'a> {
210    if let Some(stmt) = parse_statement(statement, issues, &options.parse_options) {
211        let mut typer = Typer {
212            schemas,
213            issues,
214            reference_types: Vec::new(),
215            arg_types: Default::default(),
216            options,
217            with_schemas: Default::default(),
218        };
219        let t = type_statement::type_statement(&mut typer, &stmt);
220        let arguments = typer.arg_types;
221        match t {
222            type_statement::InnerStatementType::Select(s) => StatementType::Select {
223                columns: s.columns,
224                arguments,
225            },
226            type_statement::InnerStatementType::Delete { returning } => StatementType::Delete {
227                arguments,
228                returning: returning.map(|r| r.columns),
229            },
230            type_statement::InnerStatementType::Insert {
231                auto_increment_id,
232                returning,
233            } => StatementType::Insert {
234                yield_autoincrement: auto_increment_id,
235                arguments,
236                returning: returning.map(|r| r.columns),
237            },
238            type_statement::InnerStatementType::Update => StatementType::Update { arguments },
239            type_statement::InnerStatementType::Replace { returning } => StatementType::Replace {
240                arguments,
241                returning: returning.map(|r| r.columns),
242            },
243            type_statement::InnerStatementType::Invalid => StatementType::Invalid,
244        }
245    } else {
246        StatementType::Invalid
247    }
248}
249
250#[cfg(test)]
251mod tests {
252    use alloc::vec::Vec;
253    use codespan_reporting::{
254        diagnostic::{Diagnostic, Label},
255        files::SimpleFiles,
256        term::{
257            self,
258            termcolor::{ColorChoice, StandardStream},
259        },
260    };
261    use sql_parse::{Identifier, Issue, Issues, Level, SQLArguments, SQLDialect};
262
263    use crate::{
264        schema::parse_schemas, type_statement, ArgumentKey, AutoIncrementId, BaseType, FullType,
265        SelectTypeColumn, StatementType, Type, TypeOptions,
266    };
267
268    struct N<'a>(Option<&'a str>);
269    impl<'a> alloc::fmt::Display for N<'a> {
270        fn fmt(&self, f: &mut alloc::fmt::Formatter<'_>) -> alloc::fmt::Result {
271            if let Some(v) = self.0 {
272                v.fmt(f)
273            } else {
274                f.write_str("None")
275            }
276        }
277    }
278
279    struct N2<'a>(Option<Identifier<'a>>);
280    impl<'a> alloc::fmt::Display for N2<'a> {
281        fn fmt(&self, f: &mut alloc::fmt::Formatter<'_>) -> alloc::fmt::Result {
282            if let Some(v) = &self.0 {
283                v.fmt(f)
284            } else {
285                f.write_str("None")
286            }
287        }
288    }
289
290    fn check_no_errors(name: &str, src: &str, issues: &[Issue], errors: &mut usize) {
291        let mut files = SimpleFiles::new();
292        let file_id = files.add(name, &src);
293        let writer = StandardStream::stderr(ColorChoice::Always);
294        let config = codespan_reporting::term::Config::default();
295        for issue in issues {
296            let mut labels = vec![Label::primary(file_id, issue.span.clone())];
297            for fragment in &issue.fragments {
298                labels.push(
299                    Label::secondary(file_id, fragment.span.clone())
300                        .with_message(fragment.message.to_string()),
301                );
302            }
303            let d = match issue.level {
304                Level::Error => {
305                    *errors += 1;
306                    Diagnostic::error()
307                }
308                Level::Warning => Diagnostic::warning(),
309            };
310            let d = d
311                .with_message(issue.message.to_string())
312                .with_labels(labels);
313            term::emit(&mut writer.lock(), &config, &files, &d).unwrap();
314        }
315    }
316
317    fn str_to_type(t: &str) -> FullType<'static> {
318        let (t, not_null) = if let Some(t) = t.strip_suffix('!') {
319            (t, true)
320        } else {
321            (t, false)
322        };
323        let (t, list_hack) = if let Some(v) = t.strip_suffix("[]") {
324            (v, true)
325        } else {
326            (t, false)
327        };
328        let t = match t {
329            "b" => BaseType::Bool.into(),
330            "u8" => Type::U8,
331            "u16" => Type::U16,
332            "u32" => Type::U32,
333            "u64" => Type::U64,
334            "i8" => Type::I8,
335            "i16" => Type::I16,
336            "i32" => Type::I32,
337            "i64" => Type::I64,
338            "f32" => Type::F32,
339            "f64" => Type::F64,
340            "i" => BaseType::Integer.into(),
341            "f" => BaseType::Float.into(),
342            "str" => BaseType::String.into(),
343            "bytes" => BaseType::Bytes.into(),
344            "dt" => BaseType::DateTime.into(),
345            "json" => Type::JSON,
346            "any" => BaseType::Any.into(),
347            _ => panic!("Unknown type {}", t),
348        };
349        let mut t = FullType::new(t, not_null);
350        if list_hack {
351            t.list_hack = true;
352        }
353        t
354    }
355
356    fn check_arguments(
357        name: &str,
358        got: &[(ArgumentKey<'_>, FullType<'_>)],
359        expected: &str,
360        errors: &mut usize,
361    ) {
362        if expected.is_empty() {
363            for (cnt, value) in got.iter().enumerate() {
364                println!("{}: Unexpected argument {} type {:?}", name, cnt, value);
365                *errors += 1;
366            }
367            return;
368        }
369        let mut got2 = Vec::new();
370        let inv = FullType::invalid();
371        for (k, v) in got {
372            match k {
373                ArgumentKey::Index(i) => {
374                    while got2.len() <= *i {
375                        got2.push(&inv);
376                    }
377                    got2[*i] = v;
378                }
379                ArgumentKey::Identifier(k) => {
380                    println!("{}: Got named argument {}", name, k);
381                    *errors += 1;
382                }
383            }
384        }
385        let mut cnt = 0;
386        for (i, t) in expected.split(',').enumerate() {
387            let t = t.trim();
388            let t = str_to_type(t);
389            if let Some(v) = got2.get(i) {
390                if *v != &t {
391                    println!("{}: Expected type {} for argument {} got {}", name, t, i, v);
392                    *errors += 1;
393                }
394            } else {
395                println!("{}: Expected type {} for argument {} got None", name, t, i);
396                *errors += 1;
397            }
398            cnt += 1;
399        }
400        while cnt < got.len() {
401            println!("{}: Unexpected argument {} type {:?}", name, cnt, got[cnt]);
402            cnt += 1;
403            *errors += 1;
404        }
405    }
406
407    fn check_columns(name: &str, got: &[SelectTypeColumn<'_>], expected: &str, errors: &mut usize) {
408        let mut cnt = 0;
409        for (i, t) in expected.split(',').enumerate() {
410            let t = t.trim();
411            let (cname, t) = t.split_once(":").unwrap();
412            let t = str_to_type(t);
413            let cname = if cname.is_empty() { None } else { Some(cname) };
414            if let Some(v) = got.get(i) {
415                if v.name.as_deref() != cname || v.type_ != t {
416                    println!(
417                        "{}: Expected column {} with name {} of type {} got {} of type {}",
418                        name,
419                        i,
420                        N(cname),
421                        t,
422                        N2(v.name.clone()),
423                        v.type_
424                    );
425                    *errors += 1;
426                }
427            } else {
428                println!(
429                    "{}: Expected column {} with name {} of type {} got None",
430                    name,
431                    i,
432                    N(cname),
433                    t
434                );
435                *errors += 1;
436            }
437            cnt += 1;
438        }
439        while cnt < got.len() {
440            println!(
441                "{}: Unexpected column {} with name {} of type {}",
442                name,
443                cnt,
444                N2(got[cnt].name.clone()),
445                got[cnt].type_
446            );
447            cnt += 1;
448            *errors += 1;
449        }
450    }
451
452    #[test]
453    fn mariadb() {
454        let schema_src = "
455
456        DROP TABLE IF EXISTS `t1`;
457        CREATE TABLE `t1` (
458          `id` int(11) NOT NULL,
459          `cbool` tinyint(1) NOT NULL,
460          `cu8` tinyint UNSIGNED NOT NULL,
461          `cu16` smallint UNSIGNED NOT NULL,
462          `cu32` int UNSIGNED NOT NULL,
463          `cu64` bigint UNSIGNED NOT NULL,
464          `ci8` tinyint,
465          `ci16` smallint,
466          `ci32` int,
467          `ci64` bigint,
468          `cbin` binary(16),
469          `ctext` varchar(100) NOT NULL,
470          `cbytes` blob,
471          `cf32` float,
472          `cf64` double,
473          `cu8_plus_one` tinyint UNSIGNED GENERATED ALWAYS AS (
474            `cu8` + 1
475           ) STORED,
476          `status` varchar(10) GENERATED ALWAYS AS (case when `cu8` <> 0 and `cu16` = 0 then 'a' when
477            `cbool` then 'b' when `ci32` = 42 then 'd' when `cu64` = 43 then 'x' when
478            `ci64` = 12 then 'y' else 'z' end) VIRTUAL
479        ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
480
481        ALTER TABLE `t1`
482          MODIFY `id` int(11) NOT NULL AUTO_INCREMENT;
483
484        DROP INDEX IF EXISTS `hat` ON `t1`;
485
486        CREATE INDEX `hat2` ON `t1` (`id`, `cf64`);
487
488        CREATE TABLE `t2` (
489          `id` int(11) NOT NULL AUTO_INCREMENT,
490          `t1_id` int(11) NOT NULL);
491
492        CREATE TABLE `t3` (
493            `id` int(11) NOT NULL AUTO_INCREMENT,
494            `text` TEXT);
495
496        CREATE TABLE `t4` (
497            `id` int(11) NOT NULL AUTO_INCREMENT,
498            `dt` datetime NOT NULL);
499
500        CREATE TABLE `t5` (
501            `id` int(11) NOT NULL AUTO_INCREMENT,
502            `a` int NOT NULL,
503            `b` int,
504            `c` int NOT NULL DEFAULT 42);
505        ";
506
507        let options = TypeOptions::new().dialect(SQLDialect::MariaDB);
508        let mut issues = Issues::new(schema_src);
509        let schema = parse_schemas(schema_src, &mut issues, &options);
510        let mut errors = 0;
511        check_no_errors("schema", schema_src, issues.get(), &mut errors);
512
513        let options = TypeOptions::new()
514            .dialect(SQLDialect::MariaDB)
515            .arguments(SQLArguments::QuestionMark);
516
517        {
518            let name = "q1";
519            let src =
520                "SELECT `id`, `cbool`, `cu8`, `cu8_plus_one`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
521                `ctext`, `cbytes`, `cf32`, `cf64` FROM `t1` WHERE ci8 IS NOT NULL
522                AND `cbool`=? AND `cu8`=? AND `cu16`=? AND `cu32`=? AND `cu64`=?
523                AND `ci8`=? AND `ci16`=? AND `ci32`=? AND `ci64`=?
524                AND `ctext`=? AND `cbytes`=? AND `cf32`=? AND `cf64`=?";
525
526            let mut issues: Issues<'_> = Issues::new(src);
527            let q = type_statement(&schema, src, &mut issues, &options);
528            check_no_errors(name, src, issues.get(), &mut errors);
529            if let StatementType::Select { arguments, columns } = q {
530                check_arguments(
531                    name,
532                    &arguments,
533                    "b,i,i,i,i,i,i,i,i,str,bytes,f,f",
534                    &mut errors,
535                );
536                check_columns(
537                    name,
538                    &columns,
539                    "id:i32!,cbool:b!,cu8:u8!,cu8_plus_one:u8!,cu16:u16!,cu32:u32!,cu64:u64!,
540                    ci8:i8!,ci16:i16!,ci32:i32!,ci64:i64!,ctext:str!,cbytes:bytes!,cf32:f32!,cf64:f64!",
541                    &mut errors,
542                );
543            } else {
544                println!("{} should be select", name);
545                errors += 1;
546            }
547        }
548
549        {
550            let name = "q1.1";
551            let src =
552                "SELECT `id`, `cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
553                `ctext`, `cbytes`, `cf32`, `cf64`, `cbin` FROM `t1` WHERE ci8 IS NOT NULL";
554
555            let mut issues: Issues<'_> = Issues::new(src);
556            let q = type_statement(&schema, src, &mut issues, &options);
557            check_no_errors(name, src, issues.get(), &mut errors);
558            if let StatementType::Select { arguments, columns } = q {
559                check_arguments(name, &arguments, "", &mut errors);
560                check_columns(
561                    name,
562                    &columns,
563                    "id:i32!,cbool:b!,cu8:u8!,cu16:u16!,cu32:u32!,cu64:u64!,
564                    ci8:i8!,ci16:i16,ci32:i32,ci64:i64,ctext:str!,cbytes:bytes,cf32:f32,cf64:f64,cbin:bytes",
565                    &mut errors,
566                );
567            } else {
568                println!("{} should be select", name);
569                errors += 1;
570            }
571        }
572
573        {
574            let name = "q2";
575            let src =
576            "INSERT INTO `t1` (`cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
577            `ctext`, `cbytes`, `cf32`, `cf64`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
578
579            let mut issues: Issues<'_> = Issues::new(src);
580            let q = type_statement(&schema, src, &mut issues, &options);
581            check_no_errors(name, src, issues.get(), &mut errors);
582            if let StatementType::Insert {
583                arguments,
584                yield_autoincrement,
585                returning,
586            } = q
587            {
588                check_arguments(
589                    name,
590                    &arguments,
591                    "b!,u8!,u16!,u32!,u64!,i8,i16,i32,i64,str!,bytes,f32,f64",
592                    &mut errors,
593                );
594                if yield_autoincrement != AutoIncrementId::Yes {
595                    println!("{} should yield autoincrement", name);
596                    errors += 1;
597                }
598                if returning.is_some() {
599                    println!("{} should not return columns", name);
600                    errors += 1;
601                }
602            } else {
603                println!("{} should be insert", name);
604                errors += 1;
605            }
606        }
607
608        {
609            let name = "q3";
610            let src =
611                "DELETE `t1` FROM `t1`, `t2` WHERE `t1`.`id` = `t2`.`t1_id` AND `t2`.`id` = ?";
612            let mut issues: Issues<'_> = Issues::new(src);
613            let q = type_statement(&schema, src, &mut issues, &options);
614            check_no_errors(name, src, issues.get(), &mut errors);
615            if let StatementType::Delete { arguments, .. } = q {
616                check_arguments(name, &arguments, "i", &mut errors);
617            } else {
618                println!("{} should be delete", name);
619                errors += 1;
620            }
621        }
622
623        {
624            let name = "q4";
625            let src = "INSERT INTO `t2` (`t1_id`) VALUES (?) ON DUPLICATE KEY UPDATE `t1_id`=?";
626            let mut issues: Issues<'_> = Issues::new(src);
627            let q = type_statement(&schema, src, &mut issues, &options);
628            check_no_errors(name, src, issues.get(), &mut errors);
629            if let StatementType::Insert {
630                arguments,
631                yield_autoincrement,
632                returning,
633            } = q
634            {
635                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
636                if yield_autoincrement != AutoIncrementId::Optional {
637                    println!("{} should yield optional auto increment", name);
638                    errors += 1;
639                }
640                if returning.is_some() {
641                    println!("{} should not return columns", name);
642                    errors += 1;
643                }
644            } else {
645                println!("{} should be insert", name);
646                errors += 1;
647            }
648        }
649
650        {
651            let name = "q5";
652            let src = "INSERT IGNORE INTO `t2` SET `t1_id`=?";
653            let mut issues: Issues<'_> = Issues::new(src);
654            let q = type_statement(&schema, src, &mut issues, &options);
655            check_no_errors(name, src, issues.get(), &mut errors);
656            if let StatementType::Insert {
657                arguments,
658                yield_autoincrement,
659                returning,
660            } = q
661            {
662                check_arguments(name, &arguments, "i32!", &mut errors);
663                if yield_autoincrement != AutoIncrementId::Optional {
664                    println!("{} should yield optional auto increment", name);
665                    errors += 1;
666                }
667                if returning.is_some() {
668                    println!("{} should not return columns", name);
669                    errors += 1;
670                }
671            } else {
672                println!("{} should be insert", name);
673                errors += 1;
674            }
675        }
676
677        {
678            let name = "q6";
679            let src = "SELECT IF(`ci32` IS NULL, `cbool`, ?) AS `cc` FROM `t1`";
680            let mut issues: Issues<'_> = Issues::new(src);
681            let q = type_statement(&schema, src, &mut issues, &options);
682            check_no_errors(name, src, issues.get(), &mut errors);
683            if let StatementType::Select { arguments, columns } = q {
684                check_arguments(name, &arguments, "b", &mut errors);
685                check_columns(name, &columns, "cc:b", &mut errors);
686            } else {
687                println!("{} should be select", name);
688                errors += 1;
689            }
690        }
691
692        {
693            let name = "q7";
694            let src = "SELECT FROM_UNIXTIME(CAST(UNIX_TIMESTAMP() AS DOUBLE)) AS `cc` FROM `t1` WHERE `id`=?";
695            let mut issues: Issues<'_> = Issues::new(src);
696            let q = type_statement(&schema, src, &mut issues, &options);
697            check_no_errors(name, src, issues.get(), &mut errors);
698            if let StatementType::Select { arguments, columns } = q {
699                check_arguments(name, &arguments, "i", &mut errors);
700                check_columns(name, &columns, "cc:dt!", &mut errors);
701            } else {
702                println!("{} should be select", name);
703                errors += 1;
704            }
705        }
706
707        {
708            let name = "q8";
709            let src = "REPLACE INTO `t2` SET `id` = ?, `t1_id`=?";
710            let mut issues: Issues<'_> = Issues::new(src);
711            let q = type_statement(&schema, src, &mut issues, &options);
712            check_no_errors(name, src, issues.get(), &mut errors);
713            if let StatementType::Replace {
714                arguments,
715                returning,
716            } = q
717            {
718                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
719                if returning.is_some() {
720                    println!("{} should not return columns", name);
721                    errors += 1;
722                }
723            } else {
724                println!("{} should be replace", name);
725                errors += 1;
726            }
727        }
728
729        {
730            let name = "q9";
731            let src = "INSERT INTO `t2` (`t1_id`) VALUES (32) ON DUPLICATE KEY UPDATE `t1_id` = `t1_id` + VALUES(`t1_id`)";
732            let mut issues: Issues<'_> = Issues::new(src);
733            let q = type_statement(&schema, src, &mut issues, &options);
734            check_no_errors(name, src, issues.get(), &mut errors);
735            if let StatementType::Insert { arguments, .. } = q {
736                check_arguments(name, &arguments, "", &mut errors);
737            } else {
738                println!("{} should be insert", name);
739                errors += 1;
740            }
741        }
742
743        {
744            let name = "q10";
745            let src =
746                "SELECT SUBSTRING_INDEX(`text`, '/', 5) AS `k` FROM `t3` WHERE `text` LIKE '%T%'";
747            let mut issues: Issues<'_> = Issues::new(src);
748            let q = type_statement(&schema, src, &mut issues, &options);
749            check_no_errors(name, src, issues.get(), &mut errors);
750            if let StatementType::Select { arguments, columns } = q {
751                check_arguments(name, &arguments, "", &mut errors);
752                check_columns(name, &columns, "k:str!", &mut errors);
753            } else {
754                println!("{} should be select", name);
755                errors += 1;
756            }
757        }
758
759        {
760            let name = "q11";
761            let src = "SELECT * FROM `t1`, `t2` LEFT JOIN `t3` ON `t3`.`id` = `t1`.`id`";
762            let mut issues: Issues<'_> = Issues::new(src);
763            type_statement(&schema, src, &mut issues, &options);
764            if !issues.get().iter().any(|i| i.level == Level::Error) {
765                println!("{} should be an error", name);
766                errors += 1;
767            }
768        }
769
770        {
771            let name = "q12";
772            let src =
773                "SELECT JSON_REPLACE('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]', 4, '$.C[3]', 3) AS `k` FROM `t3`";
774            let mut issues: Issues<'_> = Issues::new(src);
775            let q = type_statement(&schema, src, &mut issues, &options);
776            check_no_errors(name, src, issues.get(), &mut errors);
777            if let StatementType::Select { arguments, columns } = q {
778                check_arguments(name, &arguments, "", &mut errors);
779                check_columns(name, &columns, "k:json", &mut errors);
780            } else {
781                println!("{} should be select", name);
782                errors += 1;
783            }
784        }
785
786        {
787            let options = options.clone().list_hack(true);
788            let name = "q13";
789            let src = "SELECT `id` FROM `t1` WHERE `id` IN (_LIST_)";
790            let mut issues: Issues<'_> = Issues::new(src);
791            let q = type_statement(&schema, src, &mut issues, &options);
792            check_no_errors(name, src, issues.get(), &mut errors);
793            if let StatementType::Select { arguments, columns } = q {
794                check_arguments(name, &arguments, "i[]", &mut errors);
795                check_columns(name, &columns, "id:i32!", &mut errors);
796            } else {
797                println!("{} should be select", name);
798                errors += 1;
799            }
800        }
801
802        {
803            let name = "q14";
804            let src = "SELECT CAST(NULL AS CHAR) AS `id`";
805            let mut issues: Issues<'_> = Issues::new(src);
806            let q = type_statement(&schema, src, &mut issues, &options);
807            check_no_errors(name, src, issues.get(), &mut errors);
808            if let StatementType::Select { arguments, columns } = q {
809                check_arguments(name, &arguments, "", &mut errors);
810                check_columns(name, &columns, "id:str", &mut errors);
811            } else {
812                println!("{} should be select", name);
813                errors += 1;
814            }
815        }
816
817        {
818            let name = "q15";
819            let src =
820				"INSERT INTO `t1` (`cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
821            `ctext`, `cbytes`, `cf32`, `cf64`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
822                 RETURNING `id`, `cbool`, `cu8`, `ctext`, `cf64`";
823            let mut issues: Issues<'_> = Issues::new(src);
824            let q = type_statement(&schema, src, &mut issues, &options);
825            check_no_errors(name, src, issues.get(), &mut errors);
826            if let StatementType::Insert {
827                arguments,
828                yield_autoincrement,
829                returning,
830            } = q
831            {
832                check_arguments(
833                    name,
834                    &arguments,
835                    "b!,u8!,u16!,u32!,u64!,i8,i16,i32,i64,str!,bytes,f32,f64",
836                    &mut errors,
837                );
838                if yield_autoincrement != AutoIncrementId::Yes {
839                    println!("{} should yield autoincrement", name);
840                    errors += 1;
841                }
842                if let Some(returning) = returning {
843                    check_columns(
844                        name,
845                        &returning,
846                        "id:i32!,cbool:b!,cu8:u8!,ctext:str!,cf64:f64",
847                        &mut errors,
848                    );
849                } else {
850                    println!("{} should return columns", name);
851                    errors += 1;
852                }
853            } else {
854                println!("{} should be insert", name);
855                errors += 1;
856            }
857        }
858
859        {
860            let name = "q16";
861            let src = "REPLACE INTO `t2` SET `id` = ?, `t1_id`=? RETURNING `id`";
862            let mut issues: Issues<'_> = Issues::new(src);
863            let q = type_statement(&schema, src, &mut issues, &options);
864            check_no_errors(name, src, issues.get(), &mut errors);
865            if let StatementType::Replace {
866                arguments,
867                returning,
868            } = q
869            {
870                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
871                if let Some(returning) = returning {
872                    check_columns(name, &returning, "id:i32!", &mut errors);
873                } else {
874                    println!("{} should return columns", name);
875                    errors += 1;
876                }
877            } else {
878                println!("{} should be replace", name);
879                errors += 1;
880            }
881        }
882
883        {
884            let name = "q17";
885            let src = "SELECT dt, UNIX_TIMESTAMP(dt) AS t FROM t4";
886            let mut issues: Issues<'_> = Issues::new(src);
887            let q = type_statement(&schema, src, &mut issues, &options);
888            check_no_errors(name, src, issues.get(), &mut errors);
889            if let StatementType::Select { arguments, columns } = q {
890                check_arguments(name, &arguments, "", &mut errors);
891                check_columns(name, &columns, "dt:dt!,t:i64!", &mut errors);
892            } else {
893                println!("{} should be select", name);
894                errors += 1;
895            }
896        }
897
898        {
899            let name = "q17";
900            let src = "SELECT CONCAT(?, \"hat\") AS c";
901            let mut issues: Issues<'_> = Issues::new(src);
902            let q = type_statement(&schema, src, &mut issues, &options);
903            check_no_errors(name, src, issues.get(), &mut errors);
904            if let StatementType::Select { arguments, columns } = q {
905                check_arguments(name, &arguments, "any", &mut errors);
906                check_columns(name, &columns, "c:str", &mut errors);
907            } else {
908                println!("{} should be selsect", name);
909                errors += 1;
910            }
911        }
912
913        {
914            let name = "q18";
915            let src = "SELECT CAST(\"::0\" AS INET6) AS `id`";
916            let mut issues: Issues<'_> = Issues::new(src);
917            let q = type_statement(&schema, src, &mut issues, &options);
918            check_no_errors(name, src, issues.get(), &mut errors);
919            if let StatementType::Select { arguments, columns } = q {
920                check_arguments(name, &arguments, "", &mut errors);
921                check_columns(name, &columns, "id:str!", &mut errors);
922            } else {
923                println!("{} should be select", name);
924                errors += 1;
925            }
926        }
927
928        {
929            let name: &str = "q18";
930            let src = "SELECT SUBSTRING(`cbytes`, 1, 5) AS `k` FROM `t1`";
931            let mut issues: Issues<'_> = Issues::new(src);
932            let q = type_statement(&schema, src, &mut issues, &options);
933            check_no_errors(name, src, issues.get(), &mut errors);
934            if let StatementType::Select { arguments, columns } = q {
935                check_arguments(name, &arguments, "", &mut errors);
936                check_columns(name, &columns, "k:bytes", &mut errors);
937            } else {
938                println!("{} should be select", name);
939                errors += 1;
940            }
941        }
942
943        {
944            let name = "q19";
945            let src = "SELECT SUBSTRING(`ctext`, 1, 5) AS `k` FROM `t1`";
946            let mut issues: Issues<'_> = Issues::new(src);
947            let q = type_statement(&schema, src, &mut issues, &options);
948            check_no_errors(name, src, issues.get(), &mut errors);
949            if let StatementType::Select { arguments, columns } = q {
950                check_arguments(name, &arguments, "", &mut errors);
951                check_columns(name, &columns, "k:str!", &mut errors);
952            } else {
953                println!("{} should be select", name);
954                errors += 1;
955            }
956        }
957
958        {
959            let name = "q19";
960            let src = "SELECT SUBSTRING(`ctext`, 1, 5) AS `k` FROM `t1`";
961            let mut issues: Issues<'_> = Issues::new(src);
962            let q = type_statement(&schema, src, &mut issues, &options);
963            check_no_errors(name, src, issues.get(), &mut errors);
964            if let StatementType::Select { arguments, columns } = q {
965                check_arguments(name, &arguments, "", &mut errors);
966                check_columns(name, &columns, "k:str!", &mut errors);
967            } else {
968                println!("{} should be select", name);
969                errors += 1;
970            }
971        }
972
973        {
974            let name = "q20";
975            let src = "SELECT JSON_QUERY('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]') AS `k` FROM `t3`";
976            let mut issues: Issues<'_> = Issues::new(src);
977            let q = type_statement(&schema, src, &mut issues, &options);
978            check_no_errors(name, src, issues.get(), &mut errors);
979            if let StatementType::Select { arguments, columns } = q {
980                check_arguments(name, &arguments, "", &mut errors);
981                check_columns(name, &columns, "k:json", &mut errors);
982            } else {
983                println!("{} should be select", name);
984                errors += 1;
985            }
986        }
987
988        {
989            let name = "q21";
990            let src =
991                "SELECT JSON_REMOVE('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]', '$.C[3]') AS `k` FROM `t3`";
992            let mut issues: Issues<'_> = Issues::new(src);
993            let q = type_statement(&schema, src, &mut issues, &options);
994            check_no_errors(name, src, issues.get(), &mut errors);
995            if let StatementType::Select { arguments, columns } = q {
996                check_arguments(name, &arguments, "", &mut errors);
997                check_columns(name, &columns, "k:json", &mut errors);
998            } else {
999                println!("{} should be select", name);
1000                errors += 1;
1001            }
1002        }
1003
1004        {
1005            let name = "q22";
1006            let src = "SELECT JSON_OVERLAPS('false', 'false') AS `k` FROM `t3`";
1007            let mut issues: Issues<'_> = Issues::new(src);
1008            let q = type_statement(&schema, src, &mut issues, &options);
1009            check_no_errors(name, src, issues.get(), &mut errors);
1010            if let StatementType::Select { arguments, columns } = q {
1011                check_arguments(name, &arguments, "", &mut errors);
1012                check_columns(name, &columns, "k:b!", &mut errors);
1013            } else {
1014                println!("{} should be select", name);
1015                errors += 1;
1016            }
1017        }
1018
1019        {
1020            let name = "q23";
1021            let src = "SELECT JSON_OVERLAPS('false', NULL) AS `k` FROM `t3`";
1022            let mut issues: Issues<'_> = Issues::new(src);
1023            let q = type_statement(&schema, src, &mut issues, &options);
1024            check_no_errors(name, src, issues.get(), &mut errors);
1025            if let StatementType::Select { arguments, columns } = q {
1026                check_arguments(name, &arguments, "", &mut errors);
1027                check_columns(name, &columns, "k:b", &mut errors);
1028            } else {
1029                println!("{} should be select", name);
1030                errors += 1;
1031            }
1032        }
1033
1034        {
1035            let name = "q24";
1036            let src =
1037                "SELECT JSON_CONTAINS('{\"A\": 0, \"B\": [\"x\", \"y\"]}', '\"x\"', '$.B') AS `k` FROM `t3`";
1038            let mut issues: Issues<'_> = Issues::new(src);
1039            let q = type_statement(&schema, src, &mut issues, &options);
1040            check_no_errors(name, src, issues.get(), &mut errors);
1041            if let StatementType::Select { arguments, columns } = q {
1042                check_arguments(name, &arguments, "", &mut errors);
1043                check_columns(name, &columns, "k:b!", &mut errors);
1044            } else {
1045                println!("{} should be select", name);
1046                errors += 1;
1047            }
1048        }
1049
1050        {
1051            let name = "q25";
1052            let src =
1053                "SELECT JSON_CONTAINS('{\"A\": 0, \"B\": [\"x\", \"y\"]}', NULL, '$.A') AS `k` FROM `t3`";
1054            let mut issues: Issues<'_> = Issues::new(src);
1055            let q = type_statement(&schema, src, &mut issues, &options);
1056            check_no_errors(name, src, issues.get(), &mut errors);
1057            if let StatementType::Select { arguments, columns } = q {
1058                check_arguments(name, &arguments, "", &mut errors);
1059                check_columns(name, &columns, "k:b", &mut errors);
1060            } else {
1061                println!("{} should be select", name);
1062                errors += 1;
1063            }
1064        }
1065
1066        {
1067            let name = "q26";
1068            let src = "SELECT `id` FROM `t1` FORCE INDEX (`hat`)";
1069            let mut issues: Issues<'_> = Issues::new(src);
1070            type_statement(&schema, src, &mut issues, &options);
1071            if issues.is_ok() {
1072                println!("{} should fail", name);
1073                errors += 1;
1074            }
1075        }
1076
1077        {
1078            let name = "q27";
1079            let src = "SELECT `id` FROM `t1` USE INDEX (`hat2`)";
1080            let mut issues: Issues<'_> = Issues::new(src);
1081            let q = type_statement(&schema, src, &mut issues, &options);
1082            check_no_errors(name, src, issues.get(), &mut errors);
1083            if let StatementType::Select { arguments, columns } = q {
1084                check_arguments(name, &arguments, "", &mut errors);
1085                check_columns(name, &columns, "id:i32!", &mut errors);
1086            } else {
1087                println!("{} should be select", name);
1088                errors += 1;
1089            }
1090        }
1091
1092        {
1093            let name = "q28";
1094            let src = "INSERT INTO t5 (`a`) VALUES (44)";
1095            check_no_errors(name, src, issues.get(), &mut errors);
1096        }
1097
1098        {
1099            let name = "q29";
1100            let src = "INSERT INTO t5 (`a`, `b`, `c`) VALUES (?, ?)";
1101            let mut issues: Issues<'_> = Issues::new(src);
1102            type_statement(&schema, src, &mut issues, &options);
1103            if issues.is_ok() {
1104                println!("{} should fail", name);
1105                errors += 1;
1106            }
1107        }
1108
1109        {
1110            let name = "q30";
1111            let src = "INSERT INTO t5 (`a`, `b`, `c`) VALUES (?, ?, ?)";
1112            check_no_errors(name, src, issues.get(), &mut errors);
1113        }
1114
1115        {
1116            let name = "q31";
1117            let src = "INSERT INTO t5 (`a`, `b`, `c`) VALUES (?, ?, ?, ?)";
1118            let mut issues: Issues<'_> = Issues::new(src);
1119            type_statement(&schema, src, &mut issues, &options);
1120            if issues.is_ok() {
1121                println!("{} should fail", name);
1122                errors += 1;
1123            }
1124        }
1125
1126        {
1127            let name = "q32";
1128            let src = "INSERT INTO t5 (`b`, `c`) VALUES (44, 45)";
1129            let mut issues: Issues<'_> = Issues::new(src);
1130            type_statement(&schema, src, &mut issues, &options);
1131            if issues.is_ok() {
1132                println!("{} should fail", name);
1133                errors += 1;
1134            }
1135        }
1136
1137        if errors != 0 {
1138            panic!("{} errors in test", errors);
1139        }
1140    }
1141
1142    #[test]
1143    fn postgresql() {
1144        let schema_src = "
1145        BEGIN;
1146
1147        DO $$ BEGIN
1148            CREATE TYPE my_enum AS ENUM (
1149            'V1',
1150            'V2',
1151            'V3'
1152        );
1153        EXCEPTION
1154            WHEN duplicate_object THEN null;
1155        END $$;
1156
1157        CREATE TABLE IF NOT EXISTS t1 (
1158            id bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
1159            path text NOT NULL UNIQUE,
1160            v my_enum NOT NULL,
1161            time timestamptz NOT NULL DEFAULT now(),
1162            old_id bigint,
1163            CONSTRAINT t1__old
1164            FOREIGN KEY(old_id) 
1165            REFERENCES t1(id)
1166            ON DELETE SET NULL
1167        );
1168
1169        CREATE TABLE IF NOT EXISTS t2 (
1170            id bigint NOT NULL PRIMARY KEY
1171        );
1172
1173        DROP INDEX IF EXISTS t2_index;
1174
1175        CREATE INDEX t2_index2 ON t2 (id);
1176
1177        COMMIT;
1178        ";
1179
1180        let options = TypeOptions::new().dialect(SQLDialect::PostgreSQL);
1181        let mut issues = Issues::new(schema_src);
1182        let schema = parse_schemas(schema_src, &mut issues, &options);
1183        let mut errors = 0;
1184        check_no_errors("schema", schema_src, issues.get(), &mut errors);
1185
1186        let options = TypeOptions::new()
1187            .dialect(SQLDialect::PostgreSQL)
1188            .arguments(SQLArguments::Dollar);
1189
1190        {
1191            let name = "q1";
1192            let src =
1193                "INSERT INTO t2 (id) SELECT id FROM t1 WHERE path=$1 ON CONFLICT (id) DO NOTHING RETURNING id";
1194            let mut issues = Issues::new(src);
1195            let q = type_statement(&schema, src, &mut issues, &options);
1196            check_no_errors(name, src, issues.get(), &mut errors);
1197            if let StatementType::Insert {
1198                arguments,
1199                returning,
1200                ..
1201            } = q
1202            {
1203                check_arguments(name, &arguments, "str", &mut errors);
1204                check_columns(name, &returning.expect("Returning"), "id:i64!", &mut errors);
1205            } else {
1206                println!("{} should be select", name);
1207                errors += 1;
1208            }
1209        }
1210
1211        {
1212            let name = "q2";
1213            let src =
1214                "WITH hat AS (DELETE FROM t1 WHERE old_id=42 RETURNING id) INSERT INTO t2 (id) SELECT id FROM hat";
1215            let mut issues = Issues::new(src);
1216            let q = type_statement(&schema, src, &mut issues, &options);
1217            check_no_errors(name, src, issues.get(), &mut errors);
1218
1219            if let StatementType::Insert { arguments, .. } = q {
1220                check_arguments(name, &arguments, "", &mut errors);
1221            } else {
1222                println!("{} should be select {q:?}", name);
1223                errors += 1;
1224            }
1225        }
1226
1227
1228        {
1229            let name = "q3";
1230            let src = "INSERT INTO t1 (path) VALUES ('HI')";
1231            let mut issues: Issues<'_> = Issues::new(src);
1232            type_statement(&schema, src, &mut issues, &options);
1233            if issues.is_ok() {
1234                println!("{} should fail", name);
1235                errors += 1;
1236            }
1237        }
1238
1239         {
1240            let name = "q3";
1241            let src = "INSERT INTO t1 (path, v) VALUES ('HI', 'V1')";
1242            let mut issues: Issues<'_> = Issues::new(src);
1243            let q = type_statement(&schema, src, &mut issues, &options);
1244            check_no_errors(name, src, issues.get(), &mut errors);
1245
1246            if let StatementType::Insert { arguments, .. } = q {
1247                check_arguments(name, &arguments, "", &mut errors);
1248            } else {
1249                println!("{} should be insert {q:?}", name);
1250                errors += 1;
1251            }
1252        }
1253
1254        if errors != 0 {
1255            panic!("{} errors in test", errors);
1256        }
1257    }
1258}