Skip to main content

sql_fun_sqlast/
lib.rs

1#![deny(missing_docs)]
2#![allow(clippy::pedantic)]
3#![doc = include_str!("lib.md")]
4
5pub mod sem;
6pub mod syn;
7
8mod schema_dump_file;
9mod sql_file;
10mod trie_map;
11
12use miette::LabeledSpan;
13use sql_fun_core::IVec;
14
15use self::sql_file::read_sql_file;
16pub use self::trie_map::TrieMap;
17
18use std::{io::SeekFrom, string::String as StdString};
19
20use crate::{
21    sem::{ParseContext, SemAst},
22    syn::ParseResult,
23};
24
25/// Semantic AST and parse context result
26pub struct AstAndContextPair<TParseContext>(SemAst, TParseContext)
27where
28    TParseContext: ParseContext;
29
30impl<TParseContext> AstAndContextPair<TParseContext>
31where
32    TParseContext: ParseContext,
33{
34    /// create instance
35    pub fn new(ast: SemAst, context: TParseContext) -> Self {
36        Self(ast, context)
37    }
38}
39
40/// parse SQL statement
41///
42/// # Errors
43///
44/// - [`pg_query::Error`] : syntax error
45///
46pub fn parse(sql: &str) -> Result<ParseResult, pg_query::Error> {
47    let parse_result = ::pg_query::parse(sql)?;
48    Ok(ParseResult::from(parse_result.protobuf))
49}
50
51/// scan SQL Statement
52pub fn scan(sql: &str) -> Result<IVec<crate::syn::ScanToken>, pg_query::Error> {
53    let scan_result = ::pg_query::scan(sql)?;
54    let mut tokens = Vec::new();
55    for token in scan_result.tokens {
56        tokens.push(crate::syn::ScanToken::from(token));
57    }
58    Ok(tokens.into())
59}
60
/// Return the byte offset of `needle` inside `haystack`, provided `needle` is
/// a sub-slice borrowed from `haystack`'s buffer.
///
/// The offset is derived from the slices' addresses, not by searching for
/// equal text. A string with identical contents that lives in another
/// allocation therefore yields `None`. An empty `needle` positioned exactly at
/// the end of `haystack` yields `Some(haystack.len())`.
///
/// Returns `None` unless the whole of `needle` (start *and* end) lies within
/// `haystack`.
fn offset_in_string(haystack: &str, needle: &str) -> Option<usize> {
    let base = haystack.as_ptr() as usize;
    let start = needle.as_ptr() as usize;
    // `checked_sub` fails when the needle starts before the haystack.
    let offset = start.checked_sub(base)?;
    let end = offset.checked_add(needle.len())?;
    // Checking only the start pointer would accept a needle that runs past the
    // end of the haystack; require the entire needle to fit.
    (end <= haystack.len()).then_some(offset)
}
71
72///  offset and length for SQL statement
73#[derive(Debug, Clone, Copy, Default)]
74pub struct StringSpan {
75    offset: usize,
76    len: usize,
77}
78
79impl StringSpan {
80    /// get content span
81    ///
82    /// # Panics
83    ///
84    /// content not in container range
85    #[must_use]
86    pub fn from_str_in_str(container: &StdString, content: &str) -> Self {
87        let offset = offset_in_string(container, content).expect("content in container");
88        let len = content.len();
89        Self { offset, len }
90    }
91
92    /// seek position for this string span
93    pub fn seek_pos(&self) -> SeekFrom {
94        SeekFrom::Start(self.offset as u64)
95    }
96
97    /// get length of this span
98    pub fn len(&self) -> usize {
99        self.len
100    }
101
102    /// Return true when len equals 0
103    pub fn is_empty(&self) -> bool {
104        self.len == 0
105    }
106
107    fn new_labeled_span(&self, label: &str) -> LabeledSpan {
108        LabeledSpan::new(Some(String::from(label)), self.offset, self.len)
109    }
110
111    fn from_scan_token(tok: &crate::syn::ScanToken) -> Self {
112        Self {
113            offset: tok.get_start() as usize,
114            len: (tok.get_end() - tok.get_start()) as usize,
115        }
116    }
117
118    /// get end position
119    pub fn end_pos(&self) -> usize {
120        self.offset + self.len
121    }
122
123    /// merge two span
124    pub fn extend(&mut self, other: &Self) {
125        self.offset = std::cmp::min(self.offset, other.offset);
126        self.len = std::cmp::max(self.end_pos(), other.end_pos()) - self.offset;
127    }
128}
129
130pub use self::schema_dump_file::{ParseDumpFileError, ParsedSchemaDump, parse_schema_file};
131
132#[cfg(test)]
133pub mod test_helpers;
134
#[cfg(test)]
mod tests {

    use std::path::PathBuf;

    use crate::{
        sem::{BaseContext, BaseParseContext},
        syn::{ListOpt, Opt},
    };

    use super::{parse, parse_schema_file};
    use clap::Parser;
    use sql_fun_core::SqlFunArgs;
    use testresult::TestResult;

    /// Marker returned by the `enable_stack_overflow_backtrace` fixture.
    /// Requesting it makes a test call `backtrace_on_stack_overflow::enable`.
    pub struct EnableStackOverflowBacktrace {}

    /// rstest fixture: enables backtraces on stack overflow for the requesting test.
    #[rstest::fixture]
    pub fn enable_stack_overflow_backtrace() -> EnableStackOverflowBacktrace {
        // SAFETY: NOTE(review): `enable` is an `unsafe fn`, presumably installing a
        // process-wide stack-overflow handler. It is used here only as a test-time
        // debugging aid. Confirm the caller requirements against the
        // `backtrace_on_stack_overflow` documentation.
        #[expect(unsafe_code)]
        unsafe {
            backtrace_on_stack_overflow::enable()
        };
        EnableStackOverflowBacktrace {}
    }

    /// rstest fixture: CLI arguments pointing at the metadata file and `SQL_FUN_HOME`.
    #[rstest::fixture]
    pub fn context_args() -> SqlFunArgs {
        SqlFunArgs::try_parse_from(vec![
            "sqlfun",
            "--metadata-file",
            "sql_fun.metadata.toml",
            "--sql-fun-home",
            env!("SQL_FUN_HOME"),
            "subcmd",
        ])
        .expect("fixed test arguments should parse")
    }

    /// Install a tracing subscriber for all tests before they run.
    #[ctor::ctor]
    fn init_tracing() {
        use tracing_subscriber::{
            EnvFilter,
            fmt::format::{FmtSpan, debug_fn},
        };

        let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("debug"));

        // `try_init` fails when a global subscriber is already installed; that is
        // harmless for tests, so the error is deliberately ignored.
        let _ = tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_test_writer()
            .with_span_events(FmtSpan::ACTIVE)
            // Print only the `message` field; other fields are dropped.
            .fmt_fields(debug_fn(|writer, field, value| {
                if field.name() == "message" {
                    use core::fmt::Write as _;
                    write!(writer, "{value:?}")
                } else {
                    Ok(())
                }
            }))
            .try_init();
    }

    #[test]
    fn test_parse_simple_query() -> TestResult {
        let result = parse("select * from users where id=0")?;
        let Some(relname) = result
            .get_stmts()
            .get(0)
            .get_stmt()
            .as_select_stmt()
            .get_from_clause()
            .get(0)
            .as_range_var()
            .get_relname()
        else {
            panic!("select relname failed returns None")
        };

        assert_eq!(&relname, "users");

        Ok(())
    }

    #[test]
    fn test_create_composite_type() -> TestResult {
        let result = parse("CREATE TYPE compfoo AS (f1 int, f2 text);")?;
        let stmt = result.get_stmts().get(0).get_stmt();

        let Some(ct) = stmt.as_composite_type_stmt().as_inner() else {
            panic!("expected a composite type statement, got {stmt:?}");
        };

        eprintln!("{ct:?}");
        Ok(())
    }

    #[test]
    fn test_create_range_type() -> TestResult {
        let result = parse(
            "CREATE TYPE float8_range AS RANGE (subtype = float8, subtype_diff = float8mi);",
        )?;
        let stmt = result.get_stmts().get(0).get_stmt();

        let Some(ct) = stmt.as_create_range_stmt().as_inner() else {
            panic!("expected a create range statement, got {stmt:?}");
        };

        eprintln!("{ct:?}");
        Ok(())
    }

    #[test]
    fn test_create_base_type() -> TestResult {
        let result = parse(
            "CREATE TYPE box (
    INTERNALLENGTH = 16,
    INPUT = my_box_in_function,
    OUTPUT = my_box_out_function
);",
        )?;
        let stmt = result.get_stmts().get(0).get_stmt();

        let Some(ct) = stmt.as_define_stmt().as_inner() else {
            panic!("expected a define statement, got {stmt:?}");
        };

        eprintln!("{ct:?}");
        Ok(())
    }

    // Ignored by default: it reads the builtin PostgreSQL 17 schema and the
    // tablefunc extension from `SQL_FUN_HOME`, plus the adventure-works example
    // schema from `../examples`.
    #[ignore]
    #[rstest::rstest]
    fn parse_adventure_works_schema(
        context_args: SqlFunArgs,
        _enable_stack_overflow_backtrace: EnableStackOverflowBacktrace,
    ) -> TestResult {
        let mut analyze_context = BaseContext::new(context_args.clone())?;

        // Builtin catalog objects, parsed into their own context and merged in.
        let builtin_context = BaseContext::new(context_args.clone())?;
        let home_path = PathBuf::from(env!("SQL_FUN_HOME"));
        let builtin = home_path.join("postgres/17/schema.sql");
        let builtin = parse_schema_file(&builtin, builtin_context)?;
        analyze_context.extend(Box::new(builtin));

        // The tablefunc extension, handled the same way.
        let tablefunc_ext = home_path.join("postgres/17/extension/tablefunc--1.0.sql");
        let input_context = BaseContext::new(context_args.clone())?;
        let tablefunc_ext = parse_schema_file(&tablefunc_ext, input_context)?;

        analyze_context.extend(Box::new(tablefunc_ext));
        let file = PathBuf::from("../examples/adventure-works/schema.develop.sql");
        let _dump = parse_schema_file(&file, analyze_context)?;
        Ok(())
    }
}