1pub mod ast;
28pub mod identifier;
29mod parsing;
30pub mod syntax_error;
31mod syntax_node;
32mod token_text;
33mod validation;
34
35#[cfg(test)]
36mod test;
37
38use std::{marker::PhantomData, sync::Arc};
39
40pub use squawk_parser::SyntaxKind;
41
42use ast::AstNode;
43use rowan::GreenNode;
44use syntax_error::SyntaxError;
45pub use syntax_node::{SyntaxNode, SyntaxToken};
46pub use token_text::TokenText;
47
/// The result of parsing: a green tree plus the errors collected while building it.
///
/// `T` records the typed node the root is expected to be (`tree()` casts the root to
/// `T`). It is only used inside `PhantomData<fn() -> T>`, so no `T` value is stored.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    /// Root of the untyped green tree; `syntax_node()` wraps it in a fresh `SyntaxNode`.
    green: GreenNode,
    /// Errors handed to the constructor; `None` when that list was empty.
    /// Validation errors are not stored here — `errors()` computes them on demand.
    errors: Option<Arc<[SyntaxError]>>,
    /// Type-level marker for the expected root node type.
    _ty: PhantomData<fn() -> T>,
}
59
60impl<T> Clone for Parse<T> {
61 fn clone(&self) -> Parse<T> {
62 Parse {
63 green: self.green.clone(),
64 errors: self.errors.clone(),
65 _ty: PhantomData,
66 }
67 }
68}
69
70impl<T> Parse<T> {
71 fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
72 Parse {
73 green,
74 errors: if errors.is_empty() {
75 None
76 } else {
77 Some(errors.into())
78 },
79 _ty: PhantomData,
80 }
81 }
82
83 pub fn syntax_node(&self) -> SyntaxNode {
84 SyntaxNode::new_root(self.green.clone())
85 }
86
87 pub fn errors(&self) -> Vec<SyntaxError> {
88 let mut errors = if let Some(e) = self.errors.as_deref() {
89 e.to_vec()
90 } else {
91 vec![]
92 };
93 validation::validate(&self.syntax_node(), &mut errors);
94 errors
95 }
96}
97
98impl<T: AstNode> Parse<T> {
99 pub fn to_syntax(self) -> Parse<SyntaxNode> {
101 Parse {
102 green: self.green,
103 errors: self.errors,
104 _ty: PhantomData,
105 }
106 }
107
108 pub fn tree(&self) -> T {
115 T::cast(self.syntax_node()).unwrap()
116 }
117
118 pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
120 match self.errors() {
121 errors if !errors.is_empty() => Err(errors),
122 _ => Ok(self.tree()),
123 }
124 }
125}
126
127impl Parse<SyntaxNode> {
128 pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
129 if N::cast(self.syntax_node()).is_some() {
130 Some(Parse {
131 green: self.green,
132 errors: self.errors,
133 _ty: PhantomData,
134 })
135 } else {
136 None
137 }
138 }
139}
140
141pub use crate::ast::SourceFile;
143
144impl SourceFile {
145 pub fn parse(text: &str) -> Parse<SourceFile> {
146 let (green, errors) = parsing::parse_text(text);
147 let root = SyntaxNode::new_root(green.clone());
148
149 assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
150 Parse::new(green, errors)
151 }
152}
153
/// Dispatches on a syntax node by trying to cast it to each listed AST type in turn.
///
/// Each arm `path::Type(pat) => expr,` expands to
/// `if let Some(pat) = path::Type::cast(node.clone()) { expr } else ...`, tried in the
/// order written; the first successful cast wins. The final `_ => expr` arm is required
/// and runs when no cast succeeds. Every non-final arm must end with a comma.
///
/// The node expression is evaluated and `.clone()`d once per attempted cast. A bare
/// identifier is forwarded to the parenthesized `match (expr) { ... }` form.
///
/// # Example
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::FuncOption(it) => { /* use `it` */ },
///         _ => (),
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };

    (match ($node:expr) {
        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
        _ => $catch_all:expr $(,)?
    }) => {{
        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
        { $catch_all }
    }};
}
180
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile::parse` returns a `Parse` rather than a `Result`; problems are
    // reported through `errors()`.
    let parse = SourceFile::parse(source_code);
    let errors = parse.errors();
    assert!(errors.is_empty(), "unexpected errors: {errors:?}");

    // `tree()` yields the typed root node.
    let file: SourceFile = parse.tree();

    // The input contains a single `create function` statement.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Typed accessors return `Option`s, so each step is unwrapped here.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty: ast::Type = ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Typed nodes wrap an untyped `SyntaxNode`; the enum and its variant share the same one.
    let func_option_syntax: &SyntaxNode = func_option.syntax();
    assert_eq!(func_option_syntax, option.syntax());

    // An untyped node can be cast back to the typed enum.
    let _func_option: ast::FuncOption = ast::FuncOption::cast(func_option_syntax.clone())
        .expect("an AS_FUNC_OPTION node casts to ast::FuncOption");

    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // The node's range matches where its text occurs in the input. The expected range is
    // derived from the input rather than hard-coded, so it does not depend on the exact
    // whitespace inside `source_code`.
    let option_text = "as 'select 1 + 1'";
    let start = source_code.find(option_text).unwrap();
    let end = start + option_text.len();
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(
            u32::try_from(start).unwrap().into(),
            u32::try_from(end).unwrap().into(),
        )
    );

    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // Navigation: parent, first child (including tokens), and next sibling (including tokens).
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    // The node itself, `as`, whitespace, the literal node and its string token.
    assert_eq!(func_option_syntax.descendants_with_tokens().count(), 5);

    // Preorder walk over nodes and tokens, rendered as an indented tree.
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                writeln!(
                    buf,
                    "{:indent$}{:?} {:?}",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                )
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
"#
        .trim()
    );

    // Filtering descendants by cast and dispatching with `match_ast!` agree.
    let options_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|it| it.syntax().text().to_string())
        .collect();

    let mut options_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    options_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(options_cast, options_visit);
}
388
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );

        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // For every `create table`, collect its name and its `(column name, column type)` pairs.
    // Table constraints and `like` clauses are skipped; other statements are ignored.
    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let columns: Vec<(String, String)> = create_table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => Some((
                        column.name().unwrap().syntax().to_string(),
                        column.ty().unwrap().syntax().to_string(),
                    )),
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#);
}