Skip to main content

sql_fun_sqlast/sem/
top_level_statement.rs

1use sql_fun_core::IVec;
2
3use crate::{
4    AstAndContextPair,
5    sem::{
6        AnalysisError, ParseContext, analyze_alter_owner, analyze_alter_seq,
7        analyze_alter_statement, analyze_comment, analyze_create_comp_type, analyze_create_domain,
8        analyze_create_enum_type, analyze_create_extension, analyze_create_function,
9        analyze_create_schema, analyze_create_seq, analyze_create_stmt, analyze_create_table_as,
10        analyze_create_view, analyze_index_stmt, analyze_select, analyze_variable_set,
11    },
12    syn::{ListOpt, Opt, ParseResult},
13};
14
15/// analyze top-level SQL statement
16pub fn semantic_analysis<TParseContext>(
17    context: TParseContext,
18    syn: ParseResult,
19    tokens: &IVec<crate::syn::ScanToken>,
20) -> Result<AstAndContextPair<TParseContext>, AnalysisError>
21where
22    TParseContext: ParseContext,
23{
24    let statement = syn.get_stmts().get(0).get_stmt();
25    analyze_node(context, &None, statement, tokens)
26}
27
28/// analyze node
29#[tracing::instrument(skip(context))]
30pub fn analyze_node<TParseContext: ParseContext>(
31    context: TParseContext,
32    parent_schema: &Option<String>,
33    statement: crate::syn::Node,
34    tokens: &IVec<crate::syn::ScanToken>,
35) -> Result<AstAndContextPair<TParseContext>, AnalysisError> {
36    if let Some(create) = statement.as_create_stmt().as_inner() {
37        return analyze_create_stmt(context, parent_schema, create);
38    }
39
40    if let Some(create) = statement.as_create_enum_stmt().as_inner() {
41        return analyze_create_enum_type(context, parent_schema, create);
42    }
43
44    if let Some(comp_type) = statement.as_composite_type_stmt().as_inner() {
45        return analyze_create_comp_type(context, parent_schema, comp_type);
46    }
47
48    if let Some(create_extension) = statement.as_create_extension_stmt().as_inner() {
49        return analyze_create_extension(context, parent_schema, create_extension);
50    }
51
52    if let Some(create_function) = statement.as_create_function_stmt().as_inner() {
53        return analyze_create_function(context, parent_schema, create_function, tokens);
54    }
55
56    if let Some(variable_set) = statement.as_variable_set_stmt().as_inner() {
57        return analyze_variable_set(context, &variable_set, tokens);
58    }
59
60    if let Some(select) = statement.as_select_stmt().as_inner() {
61        return analyze_select(context, select, tokens);
62    }
63
64    if let Some(create_schema) = statement.as_create_schema_stmt().as_inner() {
65        return analyze_create_schema(context, create_schema, tokens);
66    }
67
68    if let Some(alter_owner) = statement.as_alter_owner_stmt().as_inner() {
69        return analyze_alter_owner(context, alter_owner);
70    }
71
72    if let Some(comment) = statement.as_comment_stmt().as_inner() {
73        return analyze_comment(context, comment);
74    }
75
76    if let Some(domain) = statement.as_create_domain_stmt().as_inner() {
77        return analyze_create_domain(context, domain, tokens);
78    }
79
80    if let Some(alter_table) = statement.as_alter_table_stmt().as_inner() {
81        return analyze_alter_statement(context, parent_schema, alter_table, tokens);
82    }
83
84    if let Some(create_view) = statement.as_view_stmt().as_inner() {
85        return analyze_create_view(context, parent_schema, create_view, tokens);
86    }
87
88    if let Some(create_seq) = statement.as_create_seq_stmt().as_inner() {
89        return analyze_create_seq(context, parent_schema, create_seq);
90    }
91
92    if let Some(alter_seq) = statement.as_alter_seq_stmt().as_inner() {
93        return analyze_alter_seq(context, parent_schema, alter_seq);
94    }
95
96    if let Some(create_table_as) = statement.as_create_table_as_stmt().as_inner() {
97        return analyze_create_table_as(context, parent_schema, create_table_as, tokens);
98    }
99
100    if let Some(index_stmt) = statement.as_index_stmt().as_inner() {
101        return analyze_index_stmt(context, parent_schema, index_stmt, tokens);
102    }
103
104    todo!("handle {statement:?}");
105}
106
#[cfg(test)]
mod tests {
    use crate::{
        AstAndContextPair,
        sem::{BaseContext, SchemaFileContext, semantic_analysis},
        tests::context_args,
    };

    use sql_fun_core::SqlFunArgs;
    use std::path::Path;
    use testresult::TestResult;

    /// Each SQL statement below must make it through top-level semantic
    /// analysis in a fresh `SchemaFileContext` without returning an error.
    #[rstest::rstest]
    #[case(r#"create table users ( id bigserial not null primary key, name varchar )"#)]
    #[case(r#"create type user_kind as enum('member', 'guest')"#)]
    #[case(r#"CREATE TYPE compfoo AS (f1 int, f2 text)"#)]
    #[case(r#"CREATE EXTENSION hstore;"#)]
    #[case(
        r#"CREATE FUNCTION hstore_subscript_handler(internal)
RETURNS internal
AS 'MODULE_PATHNAME', 'hstore_subscript_handler'
LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE;"#
    )]
    #[case("SET statement_timeout = 0;")]
    #[case(r#"CREATE DOMAIN public."AccountNumber" AS character varying(15);"#)]
    fn test_schema_file_context_accept_ddl(
        #[case] sql: &str,
        context_args: SqlFunArgs,
    ) -> TestResult {
        // The analyzer consumes both the parse result and the raw token stream.
        let parse_result = crate::parse(sql)?;
        let scanned_tokens = crate::scan(sql)?;

        // Dump the parse result so a failing case shows what was analyzed.
        eprintln!("{parse_result:?}");

        let schema_path = Path::new("test.sql");
        let schema_file_context =
            SchemaFileContext::new(&schema_path, BaseContext::new(context_args)?);

        let AstAndContextPair(_analyzed_ast, _analyzed_context) =
            semantic_analysis(schema_file_context, parse_result, &scanned_tokens)?;

        Ok(())
    }
}