sql_type/
lib.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
// Library builds are `no_std` (collections come from the `alloc` crate); test
// builds keep `std`, which the test module relies on (e.g. for `println!`).
#![cfg_attr(not(test), no_std)]
#![forbid(unsafe_code)]

//! Crate for typing SQL statements.
//!
//! ```
//! use sql_type::{schema::parse_schemas, type_statement, TypeOptions,
//!     SQLDialect, SQLArguments, StatementType, Issues};
//! let schemas = "
//!     CREATE TABLE `events` (
//!       `id` bigint(20) NOT NULL,
//!       `user` int(11) NOT NULL,
//!       `message` text NOT NULL
//!     );";
//!
//! let mut issues = Issues::new(schemas);
//!
//! // Compute terse representation of the schemas
//! let schemas = parse_schemas(schemas,
//!     &mut issues,
//!     &TypeOptions::new().dialect(SQLDialect::MariaDB));
//! assert!(issues.is_ok());
//!
//! let sql = "SELECT `id`, `user`, `message` FROM `events` WHERE `id` = ?";
//! let mut issues = Issues::new(sql);
//! let stmt = type_statement(&schemas, sql, &mut issues,
//!     &TypeOptions::new().dialect(SQLDialect::MariaDB).arguments(SQLArguments::QuestionMark));
//! assert!(issues.is_ok());
//!
//! let stmt = match stmt {
//!     StatementType::Select{columns, arguments} => {
//!         assert_eq!(columns.len(), 3);
//!         assert_eq!(arguments.len(), 1);
//!     }
//!     _ => panic!("Expected select statement")
//! };
//! ```
49
50extern crate alloc;
51
52use alloc::vec::Vec;
53use schema::Schemas;
54pub use sql_parse::{Fragment, Issue, Issues, Level};
55use sql_parse::{ParseOptions, parse_statement};
56
57mod type_;
58mod type_binary_expression;
59mod type_delete;
60mod type_expression;
61mod type_function;
62mod type_insert_replace;
63mod type_reference;
64mod type_select;
65mod type_statement;
66mod type_update;
67mod typer;
68
69pub mod schema;
70pub use type_::{BaseType, FullType, Type};
71pub use type_insert_replace::AutoIncrementId;
72pub use type_select::SelectTypeColumn;
73use typer::Typer;
74
75pub use sql_parse::{SQLArguments, SQLDialect};
76
77/// Options used when typing sql or parsing a schema
78#[derive(Debug, Default, Clone)]
79pub struct TypeOptions {
80    parse_options: ParseOptions,
81    warn_unnamed_column_in_select: bool,
82    warn_duplicate_column_in_select: bool,
83}
84
85impl TypeOptions {
86    /// Produce new default options
87    pub fn new() -> Self {
88        Default::default()
89    }
90
91    /// Change what sql dialect is used
92    pub fn dialect(self, dialect: SQLDialect) -> Self {
93        Self {
94            parse_options: self.parse_options.dialect(dialect),
95            ..self
96        }
97    }
98
99    /// Change how sql arguments are supplied
100    pub fn arguments(self, arguments: SQLArguments) -> Self {
101        Self {
102            parse_options: self.parse_options.arguments(arguments),
103            ..self
104        }
105    }
106
107    /// Should we warn about unquoted identifiers
108    pub fn warn_unquoted_identifiers(self, warn_unquoted_identifiers: bool) -> Self {
109        Self {
110            parse_options: self
111                .parse_options
112                .warn_unquoted_identifiers(warn_unquoted_identifiers),
113            ..self
114        }
115    }
116
117    /// Should we warn about keywords not in ALL CAPS
118    pub fn warn_none_capital_keywords(self, warn_none_capital_keywords: bool) -> Self {
119        Self {
120            parse_options: self
121                .parse_options
122                .warn_none_capital_keywords(warn_none_capital_keywords),
123            ..self
124        }
125    }
126
127    /// Should we warn about unnamed columns in selects
128    pub fn warn_unnamed_column_in_select(self, warn_unnamed_column_in_select: bool) -> Self {
129        Self {
130            warn_unnamed_column_in_select,
131            ..self
132        }
133    }
134
135    /// Should we warn about duplicate columns in selects
136    pub fn warn_duplicate_column_in_select(self, warn_duplicate_column_in_select: bool) -> Self {
137        Self {
138            warn_duplicate_column_in_select,
139            ..self
140        }
141    }
142
143    /// Parse _LIST_ as special expression and type as a list of items
144    pub fn list_hack(self, list_hack: bool) -> Self {
145        Self {
146            parse_options: self.parse_options.list_hack(list_hack),
147            ..self
148        }
149    }
150}
151
/// Identifies one argument of a typed statement.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ArgumentKey<'a> {
    /// A positional (unnamed) argument, identified by its index
    Index(usize),
    /// A named argument, identified by its name
    Identifier(&'a str),
}
160
161/// Type information of typed statement
162#[derive(Debug, Clone)]
163pub enum StatementType<'a> {
164    /// The statement was a select statement
165    Select {
166        /// The types and named of the columns return from the select
167        columns: Vec<SelectTypeColumn<'a>>,
168        /// The key and type of arguments to the query
169        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
170    },
171    /// The statement is a delete statement
172    Delete {
173        /// The key and type of arguments to the query
174        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
175        /// If present, the types and names of the columns returned from the delete
176        returning: Option<Vec<SelectTypeColumn<'a>>>,
177    },
178    /// The statement is an insert statement
179    Insert {
180        /// The insert happend in a table with a auto increment id row
181        yield_autoincrement: AutoIncrementId,
182        /// The key and type of arguments to the query
183        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
184        /// If present, the types and names of the columns returned from the insert
185        returning: Option<Vec<SelectTypeColumn<'a>>>,
186    },
187    /// The statement is a update statement
188    Update {
189        /// The key and type of arguments to the query
190        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
191        /// If present, the types and names of the columns returned from the insert
192        returning: Option<Vec<SelectTypeColumn<'a>>>,
193    },
194    /// The statement is a replace statement
195    Replace {
196        /// The key and type of arguments to the query
197        arguments: Vec<(ArgumentKey<'a>, FullType<'a>)>,
198        /// If present, the types and names of the columns returned from the replace
199        returning: Option<Vec<SelectTypeColumn<'a>>>,
200    },
201    /// The query was not valid, errors are preset in issues
202    Invalid,
203}
204
205/// Type an sql statement with respect to a given schema
206pub fn type_statement<'a>(
207    schemas: &'a Schemas<'a>,
208    statement: &'a str,
209    issues: &mut Issues<'a>,
210    options: &TypeOptions,
211) -> StatementType<'a> {
212    if let Some(stmt) = parse_statement(statement, issues, &options.parse_options) {
213        let mut typer = Typer {
214            schemas,
215            issues,
216            reference_types: Vec::new(),
217            arg_types: Default::default(),
218            options,
219            with_schemas: Default::default(),
220        };
221        let t = type_statement::type_statement(&mut typer, &stmt);
222        let arguments = typer.arg_types;
223        match t {
224            type_statement::InnerStatementType::Select(s) => StatementType::Select {
225                columns: s.columns,
226                arguments,
227            },
228            type_statement::InnerStatementType::Delete { returning } => StatementType::Delete {
229                arguments,
230                returning: returning.map(|r| r.columns),
231            },
232            type_statement::InnerStatementType::Insert {
233                auto_increment_id,
234                returning,
235            } => StatementType::Insert {
236                yield_autoincrement: auto_increment_id,
237                arguments,
238                returning: returning.map(|r| r.columns),
239            },
240            type_statement::InnerStatementType::Update { returning } => StatementType::Update {
241                arguments,
242                returning: returning.map(|r| r.columns),
243            },
244            type_statement::InnerStatementType::Replace { returning } => StatementType::Replace {
245                arguments,
246                returning: returning.map(|r| r.columns),
247            },
248            type_statement::InnerStatementType::Invalid => StatementType::Invalid,
249        }
250    } else {
251        StatementType::Invalid
252    }
253}
254
255#[cfg(test)]
256mod tests {
257    use alloc::vec::Vec;
258    use codespan_reporting::{
259        diagnostic::{Diagnostic, Label},
260        files::SimpleFiles,
261        term::{
262            self,
263            termcolor::{ColorChoice, StandardStream},
264        },
265    };
266    use sql_parse::{Identifier, Issue, Issues, Level, SQLArguments, SQLDialect};
267
268    use crate::{
269        ArgumentKey, AutoIncrementId, BaseType, FullType, SelectTypeColumn, StatementType, Type,
270        TypeOptions, schema::parse_schemas, type_statement,
271    };
272
273    struct N<'a>(Option<&'a str>);
274    impl<'a> alloc::fmt::Display for N<'a> {
275        fn fmt(&self, f: &mut alloc::fmt::Formatter<'_>) -> alloc::fmt::Result {
276            if let Some(v) = self.0 {
277                v.fmt(f)
278            } else {
279                f.write_str("None")
280            }
281        }
282    }
283
284    struct N2<'a>(Option<Identifier<'a>>);
285    impl<'a> alloc::fmt::Display for N2<'a> {
286        fn fmt(&self, f: &mut alloc::fmt::Formatter<'_>) -> alloc::fmt::Result {
287            if let Some(v) = &self.0 {
288                v.fmt(f)
289            } else {
290                f.write_str("None")
291            }
292        }
293    }
294
295    fn check_no_errors(name: &str, src: &str, issues: &[Issue], errors: &mut usize) {
296        let mut files = SimpleFiles::new();
297        let file_id = files.add(name, &src);
298        let writer = StandardStream::stderr(ColorChoice::Always);
299        let config = codespan_reporting::term::Config::default();
300        for issue in issues {
301            let mut labels = vec![Label::primary(file_id, issue.span.clone())];
302            for fragment in &issue.fragments {
303                labels.push(
304                    Label::secondary(file_id, fragment.span.clone())
305                        .with_message(fragment.message.to_string()),
306                );
307            }
308            let d = match issue.level {
309                Level::Error => {
310                    *errors += 1;
311                    Diagnostic::error()
312                }
313                Level::Warning => Diagnostic::warning(),
314            };
315            let d = d
316                .with_message(issue.message.to_string())
317                .with_labels(labels);
318            term::emit(&mut writer.lock(), &config, &files, &d).unwrap();
319        }
320    }
321
322    fn str_to_type(t: &str) -> FullType<'static> {
323        let (t, not_null) = if let Some(t) = t.strip_suffix('!') {
324            (t, true)
325        } else {
326            (t, false)
327        };
328        let (t, list_hack) = if let Some(v) = t.strip_suffix("[]") {
329            (v, true)
330        } else {
331            (t, false)
332        };
333        let t = match t {
334            "b" => BaseType::Bool.into(),
335            "u8" => Type::U8,
336            "u16" => Type::U16,
337            "u32" => Type::U32,
338            "u64" => Type::U64,
339            "i8" => Type::I8,
340            "i16" => Type::I16,
341            "i32" => Type::I32,
342            "i64" => Type::I64,
343            "f32" => Type::F32,
344            "f64" => Type::F64,
345            "i" => BaseType::Integer.into(),
346            "f" => BaseType::Float.into(),
347            "str" => BaseType::String.into(),
348            "bytes" => BaseType::Bytes.into(),
349            "dt" => BaseType::DateTime.into(),
350            "json" => Type::JSON,
351            "any" => BaseType::Any.into(),
352            _ => panic!("Unknown type {t}"),
353        };
354        let mut t = FullType::new(t, not_null);
355        if list_hack {
356            t.list_hack = true;
357        }
358        t
359    }
360
361    fn check_arguments(
362        name: &str,
363        got: &[(ArgumentKey<'_>, FullType<'_>)],
364        expected: &str,
365        errors: &mut usize,
366    ) {
367        if expected.is_empty() {
368            for (cnt, value) in got.iter().enumerate() {
369                println!("{name}: Unexpected argument {cnt} type {value:?}");
370                *errors += 1;
371            }
372            return;
373        }
374        let mut got2 = Vec::new();
375        let inv = FullType::invalid();
376        for (k, v) in got {
377            match k {
378                ArgumentKey::Index(i) => {
379                    while got2.len() <= *i {
380                        got2.push(&inv);
381                    }
382                    got2[*i] = v;
383                }
384                ArgumentKey::Identifier(k) => {
385                    println!("{name}: Got named argument {k}");
386                    *errors += 1;
387                }
388            }
389        }
390        let mut cnt = 0;
391        for (i, t) in expected.split(',').enumerate() {
392            let t = t.trim();
393            let t = str_to_type(t);
394            if let Some(v) = got2.get(i) {
395                if *v != &t {
396                    println!("{name}: Expected type {t} for argument {i} got {v}");
397                    *errors += 1;
398                }
399            } else {
400                println!("{name}: Expected type {t} for argument {i} got None");
401                *errors += 1;
402            }
403            cnt += 1;
404        }
405        while cnt < got.len() {
406            println!("{}: Unexpected argument {} type {:?}", name, cnt, got[cnt]);
407            cnt += 1;
408            *errors += 1;
409        }
410    }
411
412    fn check_columns(name: &str, got: &[SelectTypeColumn<'_>], expected: &str, errors: &mut usize) {
413        let mut cnt = 0;
414        for (i, t) in expected.split(',').enumerate() {
415            let t = t.trim();
416            let (cname, t) = t.split_once(":").unwrap();
417            let t = str_to_type(t);
418            let cname = if cname.is_empty() { None } else { Some(cname) };
419            if let Some(v) = got.get(i) {
420                if v.name.as_deref() != cname || v.type_ != t {
421                    println!(
422                        "{}: Expected column {} with name {} of type {} got {} of type {}",
423                        name,
424                        i,
425                        N(cname),
426                        t,
427                        N2(v.name.clone()),
428                        v.type_
429                    );
430                    *errors += 1;
431                }
432            } else {
433                println!(
434                    "{}: Expected column {} with name {} of type {} got None",
435                    name,
436                    i,
437                    N(cname),
438                    t
439                );
440                *errors += 1;
441            }
442            cnt += 1;
443        }
444        while cnt < got.len() {
445            println!(
446                "{}: Unexpected column {} with name {} of type {}",
447                name,
448                cnt,
449                N2(got[cnt].name.clone()),
450                got[cnt].type_
451            );
452            cnt += 1;
453            *errors += 1;
454        }
455    }
456
457    #[test]
458    fn mariadb() {
459        let schema_src = "
460
461        DROP TABLE IF EXISTS `t1`;
462        CREATE TABLE `t1` (
463          `id` int(11) NOT NULL,
464          `cbool` tinyint(1) NOT NULL,
465          `cu8` tinyint UNSIGNED NOT NULL,
466          `cu16` smallint UNSIGNED NOT NULL,
467          `cu32` int UNSIGNED NOT NULL,
468          `cu64` bigint UNSIGNED NOT NULL,
469          `ci8` tinyint,
470          `ci16` smallint,
471          `ci32` int,
472          `ci64` bigint,
473          `cbin` binary(16),
474          `ctext` varchar(100) NOT NULL,
475          `cbytes` blob,
476          `cf32` float,
477          `cf64` double,
478          `cu8_plus_one` tinyint UNSIGNED GENERATED ALWAYS AS (
479            `cu8` + 1
480           ) STORED,
481          `status` varchar(10) GENERATED ALWAYS AS (case when `cu8` <> 0 and `cu16` = 0 then 'a' when
482            `cbool` then 'b' when `ci32` = 42 then 'd' when `cu64` = 43 then 'x' when
483            `ci64` = 12 then 'y' else 'z' end) VIRTUAL
484        ) ENGINE=InnoDB DEFAULT CHARSET=utf8;
485
486        ALTER TABLE `t1`
487          MODIFY `id` int(11) NOT NULL AUTO_INCREMENT;
488
489        DROP INDEX IF EXISTS `hat` ON `t1`;
490
491        CREATE INDEX `hat2` ON `t1` (`id`, `cf64`);
492
493        CREATE TABLE `t2` (
494          `id` int(11) NOT NULL AUTO_INCREMENT,
495          `t1_id` int(11) NOT NULL);
496
497        CREATE TABLE `t3` (
498            `id` int(11) NOT NULL AUTO_INCREMENT,
499            `text` TEXT);
500
501        CREATE TABLE `t4` (
502            `id` int(11) NOT NULL AUTO_INCREMENT,
503            `dt` datetime NOT NULL);
504
505        CREATE TABLE `t5` (
506            `id` int(11) NOT NULL AUTO_INCREMENT,
507            `a` int NOT NULL,
508            `b` int,
509            `c` int NOT NULL DEFAULT 42);
510        ";
511
512        let options = TypeOptions::new().dialect(SQLDialect::MariaDB);
513        let mut issues = Issues::new(schema_src);
514        let schema = parse_schemas(schema_src, &mut issues, &options);
515        let mut errors = 0;
516        check_no_errors("schema", schema_src, issues.get(), &mut errors);
517
518        let options = TypeOptions::new()
519            .dialect(SQLDialect::MariaDB)
520            .arguments(SQLArguments::QuestionMark);
521
522        {
523            let name = "q1";
524            let src =
525                "SELECT `id`, `cbool`, `cu8`, `cu8_plus_one`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
526                `ctext`, `cbytes`, `cf32`, `cf64` FROM `t1` WHERE ci8 IS NOT NULL
527                AND `cbool`=? AND `cu8`=? AND `cu16`=? AND `cu32`=? AND `cu64`=?
528                AND `ci8`=? AND `ci16`=? AND `ci32`=? AND `ci64`=?
529                AND `ctext`=? AND `cbytes`=? AND `cf32`=? AND `cf64`=?";
530
531            let mut issues: Issues<'_> = Issues::new(src);
532            let q = type_statement(&schema, src, &mut issues, &options);
533            check_no_errors(name, src, issues.get(), &mut errors);
534            if let StatementType::Select { arguments, columns } = q {
535                check_arguments(
536                    name,
537                    &arguments,
538                    "b,i,i,i,i,i,i,i,i,str,bytes,f,f",
539                    &mut errors,
540                );
541                check_columns(
542                    name,
543                    &columns,
544                    "id:i32!,cbool:b!,cu8:u8!,cu8_plus_one:u8!,cu16:u16!,cu32:u32!,cu64:u64!,
545                    ci8:i8!,ci16:i16!,ci32:i32!,ci64:i64!,ctext:str!,cbytes:bytes!,cf32:f32!,cf64:f64!",
546                    &mut errors,
547                );
548            } else {
549                println!("{name} should be select");
550                errors += 1;
551            }
552        }
553
554        {
555            let name = "q1.1";
556            let src =
557                "SELECT `id`, `cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
558                `ctext`, `cbytes`, `cf32`, `cf64`, `cbin` FROM `t1` WHERE ci8 IS NOT NULL";
559
560            let mut issues: Issues<'_> = Issues::new(src);
561            let q = type_statement(&schema, src, &mut issues, &options);
562            check_no_errors(name, src, issues.get(), &mut errors);
563            if let StatementType::Select { arguments, columns } = q {
564                check_arguments(name, &arguments, "", &mut errors);
565                check_columns(
566                    name,
567                    &columns,
568                    "id:i32!,cbool:b!,cu8:u8!,cu16:u16!,cu32:u32!,cu64:u64!,
569                    ci8:i8!,ci16:i16,ci32:i32,ci64:i64,ctext:str!,cbytes:bytes,cf32:f32,cf64:f64,cbin:bytes",
570                    &mut errors,
571                );
572            } else {
573                println!("{name} should be select");
574                errors += 1;
575            }
576        }
577
578        {
579            let name = "q2";
580            let src =
581            "INSERT INTO `t1` (`cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
582            `ctext`, `cbytes`, `cf32`, `cf64`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
583
584            let mut issues: Issues<'_> = Issues::new(src);
585            let q = type_statement(&schema, src, &mut issues, &options);
586            check_no_errors(name, src, issues.get(), &mut errors);
587            if let StatementType::Insert {
588                arguments,
589                yield_autoincrement,
590                returning,
591            } = q
592            {
593                check_arguments(
594                    name,
595                    &arguments,
596                    "b!,u8!,u16!,u32!,u64!,i8,i16,i32,i64,str!,bytes,f32,f64",
597                    &mut errors,
598                );
599                if yield_autoincrement != AutoIncrementId::Yes {
600                    println!("{name} should yield autoincrement");
601                    errors += 1;
602                }
603                if returning.is_some() {
604                    println!("{name} should not return columns");
605                    errors += 1;
606                }
607            } else {
608                println!("{name} should be insert");
609                errors += 1;
610            }
611        }
612
613        {
614            let name = "q3";
615            let src =
616                "DELETE `t1` FROM `t1`, `t2` WHERE `t1`.`id` = `t2`.`t1_id` AND `t2`.`id` = ?";
617            let mut issues: Issues<'_> = Issues::new(src);
618            let q = type_statement(&schema, src, &mut issues, &options);
619            check_no_errors(name, src, issues.get(), &mut errors);
620            if let StatementType::Delete { arguments, .. } = q {
621                check_arguments(name, &arguments, "i", &mut errors);
622            } else {
623                println!("{name} should be delete");
624                errors += 1;
625            }
626        }
627
628        {
629            let name = "q4";
630            let src = "INSERT INTO `t2` (`t1_id`) VALUES (?) ON DUPLICATE KEY UPDATE `t1_id`=?";
631            let mut issues: Issues<'_> = Issues::new(src);
632            let q = type_statement(&schema, src, &mut issues, &options);
633            check_no_errors(name, src, issues.get(), &mut errors);
634            if let StatementType::Insert {
635                arguments,
636                yield_autoincrement,
637                returning,
638            } = q
639            {
640                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
641                if yield_autoincrement != AutoIncrementId::Optional {
642                    println!("{name} should yield optional auto increment");
643                    errors += 1;
644                }
645                if returning.is_some() {
646                    println!("{name} should not return columns");
647                    errors += 1;
648                }
649            } else {
650                println!("{name} should be insert");
651                errors += 1;
652            }
653        }
654
655        {
656            let name = "q5";
657            let src = "INSERT IGNORE INTO `t2` SET `t1_id`=?";
658            let mut issues: Issues<'_> = Issues::new(src);
659            let q = type_statement(&schema, src, &mut issues, &options);
660            check_no_errors(name, src, issues.get(), &mut errors);
661            if let StatementType::Insert {
662                arguments,
663                yield_autoincrement,
664                returning,
665            } = q
666            {
667                check_arguments(name, &arguments, "i32!", &mut errors);
668                if yield_autoincrement != AutoIncrementId::Optional {
669                    println!("{name} should yield optional auto increment");
670                    errors += 1;
671                }
672                if returning.is_some() {
673                    println!("{name} should not return columns");
674                    errors += 1;
675                }
676            } else {
677                println!("{name} should be insert");
678                errors += 1;
679            }
680        }
681
682        {
683            let name = "q6";
684            let src = "SELECT IF(`ci32` IS NULL, `cbool`, ?) AS `cc` FROM `t1`";
685            let mut issues: Issues<'_> = Issues::new(src);
686            let q = type_statement(&schema, src, &mut issues, &options);
687            check_no_errors(name, src, issues.get(), &mut errors);
688            if let StatementType::Select { arguments, columns } = q {
689                check_arguments(name, &arguments, "b", &mut errors);
690                check_columns(name, &columns, "cc:b", &mut errors);
691            } else {
692                println!("{name} should be select");
693                errors += 1;
694            }
695        }
696
697        {
698            let name = "q7";
699            let src = "SELECT FROM_UNIXTIME(CAST(UNIX_TIMESTAMP() AS DOUBLE)) AS `cc` FROM `t1` WHERE `id`=?";
700            let mut issues: Issues<'_> = Issues::new(src);
701            let q = type_statement(&schema, src, &mut issues, &options);
702            check_no_errors(name, src, issues.get(), &mut errors);
703            if let StatementType::Select { arguments, columns } = q {
704                check_arguments(name, &arguments, "i", &mut errors);
705                check_columns(name, &columns, "cc:dt!", &mut errors);
706            } else {
707                println!("{name} should be select");
708                errors += 1;
709            }
710        }
711
712        {
713            let name = "q8";
714            let src = "REPLACE INTO `t2` SET `id` = ?, `t1_id`=?";
715            let mut issues: Issues<'_> = Issues::new(src);
716            let q = type_statement(&schema, src, &mut issues, &options);
717            check_no_errors(name, src, issues.get(), &mut errors);
718            if let StatementType::Replace {
719                arguments,
720                returning,
721            } = q
722            {
723                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
724                if returning.is_some() {
725                    println!("{name} should not return columns");
726                    errors += 1;
727                }
728            } else {
729                println!("{name} should be replace");
730                errors += 1;
731            }
732        }
733
734        {
735            let name = "q9";
736            let src = "INSERT INTO `t2` (`t1_id`) VALUES (32) ON DUPLICATE KEY UPDATE `t1_id` = `t1_id` + VALUES(`t1_id`)";
737            let mut issues: Issues<'_> = Issues::new(src);
738            let q = type_statement(&schema, src, &mut issues, &options);
739            check_no_errors(name, src, issues.get(), &mut errors);
740            if let StatementType::Insert { arguments, .. } = q {
741                check_arguments(name, &arguments, "", &mut errors);
742            } else {
743                println!("{name} should be insert");
744                errors += 1;
745            }
746        }
747
748        {
749            let name = "q10";
750            let src =
751                "SELECT SUBSTRING_INDEX(`text`, '/', 5) AS `k` FROM `t3` WHERE `text` LIKE '%T%'";
752            let mut issues: Issues<'_> = Issues::new(src);
753            let q = type_statement(&schema, src, &mut issues, &options);
754            check_no_errors(name, src, issues.get(), &mut errors);
755            if let StatementType::Select { arguments, columns } = q {
756                check_arguments(name, &arguments, "", &mut errors);
757                check_columns(name, &columns, "k:str!", &mut errors);
758            } else {
759                println!("{name} should be select");
760                errors += 1;
761            }
762        }
763
764        {
765            let name = "q11";
766            let src = "SELECT * FROM `t1`, `t2` LEFT JOIN `t3` ON `t3`.`id` = `t1`.`id`";
767            let mut issues: Issues<'_> = Issues::new(src);
768            type_statement(&schema, src, &mut issues, &options);
769            if !issues.get().iter().any(|i| i.level == Level::Error) {
770                println!("{name} should be an error");
771                errors += 1;
772            }
773        }
774
775        {
776            let name = "q12";
777            let src = "SELECT JSON_REPLACE('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]', 4, '$.C[3]', 3) AS `k` FROM `t3`";
778            let mut issues: Issues<'_> = Issues::new(src);
779            let q = type_statement(&schema, src, &mut issues, &options);
780            check_no_errors(name, src, issues.get(), &mut errors);
781            if let StatementType::Select { arguments, columns } = q {
782                check_arguments(name, &arguments, "", &mut errors);
783                check_columns(name, &columns, "k:json", &mut errors);
784            } else {
785                println!("{name} should be select");
786                errors += 1;
787            }
788        }
789
790        {
791            let options = options.clone().list_hack(true);
792            let name = "q13";
793            let src = "SELECT `id` FROM `t1` WHERE `id` IN (_LIST_)";
794            let mut issues: Issues<'_> = Issues::new(src);
795            let q = type_statement(&schema, src, &mut issues, &options);
796            check_no_errors(name, src, issues.get(), &mut errors);
797            if let StatementType::Select { arguments, columns } = q {
798                check_arguments(name, &arguments, "i[]", &mut errors);
799                check_columns(name, &columns, "id:i32!", &mut errors);
800            } else {
801                println!("{name} should be select");
802                errors += 1;
803            }
804        }
805
806        {
807            let name = "q14";
808            let src = "SELECT CAST(NULL AS CHAR) AS `id`";
809            let mut issues: Issues<'_> = Issues::new(src);
810            let q = type_statement(&schema, src, &mut issues, &options);
811            check_no_errors(name, src, issues.get(), &mut errors);
812            if let StatementType::Select { arguments, columns } = q {
813                check_arguments(name, &arguments, "", &mut errors);
814                check_columns(name, &columns, "id:str", &mut errors);
815            } else {
816                println!("{name} should be select");
817                errors += 1;
818            }
819        }
820
821        {
822            let name = "q15";
823            let src =
824				"INSERT INTO `t1` (`cbool`, `cu8`, `cu16`, `cu32`, `cu64`, `ci8`, `ci16`, `ci32`, `ci64`,
825            `ctext`, `cbytes`, `cf32`, `cf64`) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
826                 RETURNING `id`, `cbool`, `cu8`, `ctext`, `cf64`";
827            let mut issues: Issues<'_> = Issues::new(src);
828            let q = type_statement(&schema, src, &mut issues, &options);
829            check_no_errors(name, src, issues.get(), &mut errors);
830            if let StatementType::Insert {
831                arguments,
832                yield_autoincrement,
833                returning,
834            } = q
835            {
836                check_arguments(
837                    name,
838                    &arguments,
839                    "b!,u8!,u16!,u32!,u64!,i8,i16,i32,i64,str!,bytes,f32,f64",
840                    &mut errors,
841                );
842                if yield_autoincrement != AutoIncrementId::Yes {
843                    println!("{name} should yield autoincrement");
844                    errors += 1;
845                }
846                if let Some(returning) = returning {
847                    check_columns(
848                        name,
849                        &returning,
850                        "id:i32!,cbool:b!,cu8:u8!,ctext:str!,cf64:f64",
851                        &mut errors,
852                    );
853                } else {
854                    println!("{name} should return columns");
855                    errors += 1;
856                }
857            } else {
858                println!("{name} should be insert");
859                errors += 1;
860            }
861        }
862
863        {
864            let name = "q16";
865            let src = "REPLACE INTO `t2` SET `id` = ?, `t1_id`=? RETURNING `id`";
866            let mut issues: Issues<'_> = Issues::new(src);
867            let q = type_statement(&schema, src, &mut issues, &options);
868            check_no_errors(name, src, issues.get(), &mut errors);
869            if let StatementType::Replace {
870                arguments,
871                returning,
872            } = q
873            {
874                check_arguments(name, &arguments, "i32!,i32!", &mut errors);
875                if let Some(returning) = returning {
876                    check_columns(name, &returning, "id:i32!", &mut errors);
877                } else {
878                    println!("{name} should return columns");
879                    errors += 1;
880                }
881            } else {
882                println!("{name} should be replace");
883                errors += 1;
884            }
885        }
886
887        {
888            let name = "q17";
889            let src = "SELECT dt, UNIX_TIMESTAMP(dt) AS t FROM t4";
890            let mut issues: Issues<'_> = Issues::new(src);
891            let q = type_statement(&schema, src, &mut issues, &options);
892            check_no_errors(name, src, issues.get(), &mut errors);
893            if let StatementType::Select { arguments, columns } = q {
894                check_arguments(name, &arguments, "", &mut errors);
895                check_columns(name, &columns, "dt:dt!,t:i64!", &mut errors);
896            } else {
897                println!("{name} should be select");
898                errors += 1;
899            }
900        }
901
902        {
903            let name = "q17";
904            let src = "SELECT CONCAT(?, \"hat\") AS c";
905            let mut issues: Issues<'_> = Issues::new(src);
906            let q = type_statement(&schema, src, &mut issues, &options);
907            check_no_errors(name, src, issues.get(), &mut errors);
908            if let StatementType::Select { arguments, columns } = q {
909                check_arguments(name, &arguments, "any", &mut errors);
910                check_columns(name, &columns, "c:str", &mut errors);
911            } else {
912                println!("{name} should be selsect");
913                errors += 1;
914            }
915        }
916
917        {
918            let name = "q18";
919            let src = "SELECT CAST(\"::0\" AS INET6) AS `id`";
920            let mut issues: Issues<'_> = Issues::new(src);
921            let q = type_statement(&schema, src, &mut issues, &options);
922            check_no_errors(name, src, issues.get(), &mut errors);
923            if let StatementType::Select { arguments, columns } = q {
924                check_arguments(name, &arguments, "", &mut errors);
925                check_columns(name, &columns, "id:str!", &mut errors);
926            } else {
927                println!("{name} should be select");
928                errors += 1;
929            }
930        }
931
932        {
933            let name: &str = "q18";
934            let src = "SELECT SUBSTRING(`cbytes`, 1, 5) AS `k` FROM `t1`";
935            let mut issues: Issues<'_> = Issues::new(src);
936            let q = type_statement(&schema, src, &mut issues, &options);
937            check_no_errors(name, src, issues.get(), &mut errors);
938            if let StatementType::Select { arguments, columns } = q {
939                check_arguments(name, &arguments, "", &mut errors);
940                check_columns(name, &columns, "k:bytes", &mut errors);
941            } else {
942                println!("{name} should be select");
943                errors += 1;
944            }
945        }
946
947        {
948            let name = "q19";
949            let src = "SELECT SUBSTRING(`ctext`, 1, 5) AS `k` FROM `t1`";
950            let mut issues: Issues<'_> = Issues::new(src);
951            let q = type_statement(&schema, src, &mut issues, &options);
952            check_no_errors(name, src, issues.get(), &mut errors);
953            if let StatementType::Select { arguments, columns } = q {
954                check_arguments(name, &arguments, "", &mut errors);
955                check_columns(name, &columns, "k:str!", &mut errors);
956            } else {
957                println!("{name} should be select");
958                errors += 1;
959            }
960        }
961
962        {
963            let name = "q19";
964            let src = "SELECT SUBSTRING(`ctext`, 1, 5) AS `k` FROM `t1`";
965            let mut issues: Issues<'_> = Issues::new(src);
966            let q = type_statement(&schema, src, &mut issues, &options);
967            check_no_errors(name, src, issues.get(), &mut errors);
968            if let StatementType::Select { arguments, columns } = q {
969                check_arguments(name, &arguments, "", &mut errors);
970                check_columns(name, &columns, "k:str!", &mut errors);
971            } else {
972                println!("{name} should be select");
973                errors += 1;
974            }
975        }
976
977        {
978            let name = "q20";
979            let src = "SELECT JSON_QUERY('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]') AS `k` FROM `t3`";
980            let mut issues: Issues<'_> = Issues::new(src);
981            let q = type_statement(&schema, src, &mut issues, &options);
982            check_no_errors(name, src, issues.get(), &mut errors);
983            if let StatementType::Select { arguments, columns } = q {
984                check_arguments(name, &arguments, "", &mut errors);
985                check_columns(name, &columns, "k:json", &mut errors);
986            } else {
987                println!("{name} should be select");
988                errors += 1;
989            }
990        }
991
992        {
993            let name = "q21";
994            let src = "SELECT JSON_REMOVE('{ \"A\": 1, \"B\": [2, 3]}', '$.B[1]', '$.C[3]') AS `k` FROM `t3`";
995            let mut issues: Issues<'_> = Issues::new(src);
996            let q = type_statement(&schema, src, &mut issues, &options);
997            check_no_errors(name, src, issues.get(), &mut errors);
998            if let StatementType::Select { arguments, columns } = q {
999                check_arguments(name, &arguments, "", &mut errors);
1000                check_columns(name, &columns, "k:json", &mut errors);
1001            } else {
1002                println!("{name} should be select");
1003                errors += 1;
1004            }
1005        }
1006
1007        {
1008            let name = "q22";
1009            let src = "SELECT JSON_OVERLAPS('false', 'false') AS `k` FROM `t3`";
1010            let mut issues: Issues<'_> = Issues::new(src);
1011            let q = type_statement(&schema, src, &mut issues, &options);
1012            check_no_errors(name, src, issues.get(), &mut errors);
1013            if let StatementType::Select { arguments, columns } = q {
1014                check_arguments(name, &arguments, "", &mut errors);
1015                check_columns(name, &columns, "k:b!", &mut errors);
1016            } else {
1017                println!("{name} should be select");
1018                errors += 1;
1019            }
1020        }
1021
1022        {
1023            let name = "q23";
1024            let src = "SELECT JSON_OVERLAPS('false', NULL) AS `k` FROM `t3`";
1025            let mut issues: Issues<'_> = Issues::new(src);
1026            let q = type_statement(&schema, src, &mut issues, &options);
1027            check_no_errors(name, src, issues.get(), &mut errors);
1028            if let StatementType::Select { arguments, columns } = q {
1029                check_arguments(name, &arguments, "", &mut errors);
1030                check_columns(name, &columns, "k:b", &mut errors);
1031            } else {
1032                println!("{name} should be select");
1033                errors += 1;
1034            }
1035        }
1036
1037        {
1038            let name = "q24";
1039            let src = "SELECT JSON_CONTAINS('{\"A\": 0, \"B\": [\"x\", \"y\"]}', '\"x\"', '$.B') AS `k` FROM `t3`";
1040            let mut issues: Issues<'_> = Issues::new(src);
1041            let q = type_statement(&schema, src, &mut issues, &options);
1042            check_no_errors(name, src, issues.get(), &mut errors);
1043            if let StatementType::Select { arguments, columns } = q {
1044                check_arguments(name, &arguments, "", &mut errors);
1045                check_columns(name, &columns, "k:b!", &mut errors);
1046            } else {
1047                println!("{name} should be select");
1048                errors += 1;
1049            }
1050        }
1051
1052        {
1053            let name = "q25";
1054            let src = "SELECT JSON_CONTAINS('{\"A\": 0, \"B\": [\"x\", \"y\"]}', NULL, '$.A') AS `k` FROM `t3`";
1055            let mut issues: Issues<'_> = Issues::new(src);
1056            let q = type_statement(&schema, src, &mut issues, &options);
1057            check_no_errors(name, src, issues.get(), &mut errors);
1058            if let StatementType::Select { arguments, columns } = q {
1059                check_arguments(name, &arguments, "", &mut errors);
1060                check_columns(name, &columns, "k:b", &mut errors);
1061            } else {
1062                println!("{name} should be select");
1063                errors += 1;
1064            }
1065        }
1066
1067        {
1068            let name = "q26";
1069            let src = "SELECT `id` FROM `t1` FORCE INDEX (`hat`)";
1070            let mut issues: Issues<'_> = Issues::new(src);
1071            type_statement(&schema, src, &mut issues, &options);
1072            if issues.is_ok() {
1073                println!("{name} should fail");
1074                errors += 1;
1075            }
1076        }
1077
1078        {
1079            let name = "q27";
1080            let src = "SELECT `id` FROM `t1` USE INDEX (`hat2`)";
1081            let mut issues: Issues<'_> = Issues::new(src);
1082            let q = type_statement(&schema, src, &mut issues, &options);
1083            check_no_errors(name, src, issues.get(), &mut errors);
1084            if let StatementType::Select { arguments, columns } = q {
1085                check_arguments(name, &arguments, "", &mut errors);
1086                check_columns(name, &columns, "id:i32!", &mut errors);
1087            } else {
1088                println!("{name} should be select");
1089                errors += 1;
1090            }
1091        }
1092
1093        {
1094            let name = "q28";
1095            let src = "INSERT INTO t5 (`a`) VALUES (44)";
1096            check_no_errors(name, src, issues.get(), &mut errors);
1097        }
1098
1099        {
1100            let name = "q29";
1101            let src = "INSERT INTO t5 (`a`, `b`, `c`) VALUES (?, ?)";
1102            let mut issues: Issues<'_> = Issues::new(src);
1103            type_statement(&schema, src, &mut issues, &options);
1104            if issues.is_ok() {
1105                println!("{name} should fail");
1106                errors += 1;
1107            }
1108        }
1109
1110        {
1111            let name = "q30";
1112            let src = "INSERT INTO t5 (`a`, `b`, `c`) VALUES (?, ?, ?)";
1113            check_no_errors(name, src, issues.get(), &mut errors);
1114        }
1115
1116        {
1117            let name = "q31";
1118            let src = "INSERT INTO t5 (`a`, `b`, `c`) VALUES (?, ?, ?, ?)";
1119            let mut issues: Issues<'_> = Issues::new(src);
1120            type_statement(&schema, src, &mut issues, &options);
1121            if issues.is_ok() {
1122                println!("{name} should fail");
1123                errors += 1;
1124            }
1125        }
1126
1127        {
1128            let name = "q32";
1129            let src = "INSERT INTO t5 (`b`, `c`) VALUES (44, 45)";
1130            let mut issues: Issues<'_> = Issues::new(src);
1131            type_statement(&schema, src, &mut issues, &options);
1132            if issues.is_ok() {
1133                println!("{name} should fail");
1134                errors += 1;
1135            }
1136        }
1137
1138        if errors != 0 {
1139            panic!("{errors} errors in test");
1140        }
1141    }
1142
1143    #[test]
1144    fn postgresql() {
1145        let schema_src = "
1146        BEGIN;
1147
1148        DO $$ BEGIN
1149            CREATE TYPE my_enum AS ENUM (
1150            'V1',
1151            'V2',
1152            'V3'
1153        );
1154        EXCEPTION
1155            WHEN duplicate_object THEN null;
1156        END $$;
1157
1158        CREATE TABLE IF NOT EXISTS t1 (
1159            id bigint NOT NULL PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
1160            path text NOT NULL UNIQUE,
1161            v my_enum NOT NULL,
1162            time timestamptz NOT NULL DEFAULT now(),
1163            old_id bigint,
1164            CONSTRAINT t1__old
1165            FOREIGN KEY(old_id) 
1166            REFERENCES t1(id)
1167            ON DELETE SET NULL
1168        );
1169
1170        CREATE TABLE IF NOT EXISTS t2 (
1171            id bigint NOT NULL PRIMARY KEY
1172        );
1173
1174        DROP INDEX IF EXISTS t2_index;
1175
1176        CREATE INDEX t2_index2 ON t2 (id);
1177
1178        COMMIT;
1179        ";
1180
1181        let options = TypeOptions::new().dialect(SQLDialect::PostgreSQL);
1182        let mut issues = Issues::new(schema_src);
1183        let schema = parse_schemas(schema_src, &mut issues, &options);
1184        let mut errors = 0;
1185        check_no_errors("schema", schema_src, issues.get(), &mut errors);
1186
1187        let options = TypeOptions::new()
1188            .dialect(SQLDialect::PostgreSQL)
1189            .arguments(SQLArguments::Dollar);
1190
1191        {
1192            let name = "q1";
1193            let src = "INSERT INTO t2 (id) SELECT id FROM t1 WHERE path=$1 ON CONFLICT (id) DO NOTHING RETURNING id";
1194            let mut issues = Issues::new(src);
1195            let q = type_statement(&schema, src, &mut issues, &options);
1196            check_no_errors(name, src, issues.get(), &mut errors);
1197            if let StatementType::Insert {
1198                arguments,
1199                returning,
1200                ..
1201            } = q
1202            {
1203                check_arguments(name, &arguments, "str", &mut errors);
1204                check_columns(name, &returning.expect("Returning"), "id:i64!", &mut errors);
1205            } else {
1206                println!("{name} should be select");
1207                errors += 1;
1208            }
1209        }
1210
1211        {
1212            let name = "q2";
1213            let src = "WITH hat AS (DELETE FROM t1 WHERE old_id=42 RETURNING id) INSERT INTO t2 (id) SELECT id FROM hat";
1214            let mut issues = Issues::new(src);
1215            let q = type_statement(&schema, src, &mut issues, &options);
1216            check_no_errors(name, src, issues.get(), &mut errors);
1217
1218            if let StatementType::Insert { arguments, .. } = q {
1219                check_arguments(name, &arguments, "", &mut errors);
1220            } else {
1221                println!("{name} should be select {q:?}");
1222                errors += 1;
1223            }
1224        }
1225
1226        {
1227            let name = "q3";
1228            let src = "INSERT INTO t1 (path) VALUES ('HI')";
1229            let mut issues: Issues<'_> = Issues::new(src);
1230            type_statement(&schema, src, &mut issues, &options);
1231            if issues.is_ok() {
1232                println!("{name} should fail");
1233                errors += 1;
1234            }
1235        }
1236
1237        {
1238            let name = "q4";
1239            let src = "INSERT INTO t1 (path, v) VALUES ('HI', 'V1')";
1240            let mut issues: Issues<'_> = Issues::new(src);
1241            let q = type_statement(&schema, src, &mut issues, &options);
1242            check_no_errors(name, src, issues.get(), &mut errors);
1243
1244            if let StatementType::Insert { arguments, .. } = q {
1245                check_arguments(name, &arguments, "", &mut errors);
1246            } else {
1247                println!("{name} should be insert {q:?}");
1248                errors += 1;
1249            }
1250        }
1251
1252        {
1253            let name = "q5";
1254            let src = "UPDATE t1 SET path='HI' RETURNING id";
1255            let mut issues: Issues<'_> = Issues::new(src);
1256            let q = type_statement(&schema, src, &mut issues, &options);
1257            if let StatementType::Update {
1258                arguments,
1259                returning,
1260                ..
1261            } = q
1262            {
1263                check_arguments(name, &arguments, "", &mut errors);
1264                if returning.is_none() {
1265                    println!("{name} should have returning");
1266                    errors += 1;
1267                }
1268            } else {
1269                println!("{name} should be update {q:?}");
1270                errors += 1;
1271            }
1272        }
1273
1274        if errors != 0 {
1275            panic!("{errors} errors in test");
1276        }
1277    }
1278
1279    #[test]
1280    fn sqlite() {
1281        let schema_src = "
1282         CREATE TABLE IF NOT EXISTS `t1` (
1283            `id` INTEGER NOT NULL PRIMARY KEY,
1284            `sid` TEXT NOT NULL) STRICT;
1285        CREATE UNIQUE INDEX IF NOT EXISTS `t1_sid` ON `t1` (`sid`);
1286        ";
1287
1288        let options = TypeOptions::new().dialect(SQLDialect::Sqlite);
1289        let mut issues = Issues::new(schema_src);
1290        let schema = parse_schemas(schema_src, &mut issues, &options);
1291        let mut errors = 0;
1292        check_no_errors("schema", schema_src, issues.get(), &mut errors);
1293
1294        let options = TypeOptions::new()
1295            .dialect(SQLDialect::Sqlite)
1296            .arguments(SQLArguments::QuestionMark);
1297
1298        {
1299            let name = "q1";
1300            let src = "INSERT INTO `t1` (`sid`) VALUES (?)";
1301            let mut issues = Issues::new(src);
1302            let q = type_statement(&schema, src, &mut issues, &options);
1303            check_no_errors(name, src, issues.get(), &mut errors);
1304            if !matches!(q, StatementType::Insert { .. }) {
1305                println!("{name} should be select");
1306                errors += 1;
1307            }
1308        }
1309
1310        if errors != 0 {
1311            panic!("{errors} errors in test");
1312        }
1313    }
1314}