1pub mod ast;
28pub mod identifier;
29mod parsing;
30mod ptr;
31pub mod syntax_error;
32mod syntax_node;
33mod token_text;
34mod validation;
35
36#[cfg(test)]
37mod test;
38
39use std::{marker::PhantomData, sync::Arc};
40
41pub use squawk_parser::SyntaxKind;
42
43use ast::AstNode;
44pub use ptr::{AstPtr, SyntaxNodePtr};
45use rowan::GreenNode;
46use syntax_error::SyntaxError;
47pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
48pub use token_text::TokenText;
49
/// The result of parsing: the root green tree plus the errors the parser
/// reported while building it.
///
/// `T` names the typed AST node the root is expected to cast to (see
/// `tree`); no `T` value is stored.
///
/// Note: `#[derive(...)]` adds `T: Debug` / `T: PartialEq` / `T: Eq` bounds
/// to the derived impls even though `T` is only a marker.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    /// Root of the parsed green tree; typed and untyped views are rebuilt
    /// from it on demand.
    green: GreenNode,
    /// Parser errors, or `None` when the parser reported none (an empty list
    /// is never stored). Validation errors are not kept here; they are
    /// computed by `errors()`.
    errors: Option<Arc<[SyntaxError]>>,
    /// `PhantomData<fn() -> T>` records `T` without owning a `T`, so auto
    /// traits such as `Send`/`Sync` do not depend on `T`.
    _ty: PhantomData<fn() -> T>,
}
61
62impl<T> Clone for Parse<T> {
63 fn clone(&self) -> Parse<T> {
64 Parse {
65 green: self.green.clone(),
66 errors: self.errors.clone(),
67 _ty: PhantomData,
68 }
69 }
70}
71
72impl<T> Parse<T> {
73 fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
74 Parse {
75 green,
76 errors: if errors.is_empty() {
77 None
78 } else {
79 Some(errors.into())
80 },
81 _ty: PhantomData,
82 }
83 }
84
85 pub fn syntax_node(&self) -> SyntaxNode {
86 SyntaxNode::new_root(self.green.clone())
87 }
88
89 pub fn errors(&self) -> Vec<SyntaxError> {
90 let mut errors = if let Some(e) = self.errors.as_deref() {
91 e.to_vec()
92 } else {
93 vec![]
94 };
95 validation::validate(&self.syntax_node(), &mut errors);
96 errors
97 }
98}
99
100impl<T: AstNode> Parse<T> {
101 pub fn to_syntax(self) -> Parse<SyntaxNode> {
103 Parse {
104 green: self.green,
105 errors: self.errors,
106 _ty: PhantomData,
107 }
108 }
109
110 pub fn tree(&self) -> T {
117 T::cast(self.syntax_node()).unwrap()
118 }
119
120 pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
122 match self.errors() {
123 errors if !errors.is_empty() => Err(errors),
124 _ => Ok(self.tree()),
125 }
126 }
127}
128
129impl Parse<SyntaxNode> {
130 pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
131 if N::cast(self.syntax_node()).is_some() {
132 Some(Parse {
133 green: self.green,
134 errors: self.errors,
135 _ty: PhantomData,
136 })
137 } else {
138 None
139 }
140 }
141}
142
143pub use crate::ast::SourceFile;
145
146impl SourceFile {
147 pub fn parse(text: &str) -> Parse<SourceFile> {
148 let (green, errors) = parsing::parse_text(text);
149 let root = SyntaxNode::new_root(green.clone());
150
151 assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
152 Parse::new(green, errors)
153 }
154}
155
/// Dispatches on a node by trying to `cast` it to each listed type in turn.
///
/// Each arm `Type(pat) => expr,` expands to
/// `if let Some(pat) = Type::cast(node.clone()) { expr } else …`, tried in
/// source order; the mandatory final `_ => expr` arm runs when none match.
/// The subject expression is re-evaluated (and cloned) once per arm tried,
/// so pass a variable or a cheap expression.
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::FuncOption(it) => { /* use `it` */ },
///         _ => (),
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    // Bare identifier subject: forward to the parenthesised form.
    (match $subject:ident { $($arms:tt)* }) => {
        $crate::match_ast!(match ($subject) { $($arms)* })
    };

    (match ($subject:expr) {
        $( $( $seg:ident )::+ ($binding:pat) => $arm:expr, )*
        _ => $fallback:expr $(,)?
    }) => {{
        $( if let Some($binding) = $($seg::)+cast($subject.clone()) { $arm } else )*
        { $fallback }
    }};
}
182
/// Tour of the syntax-tree API on a single `create function` statement:
/// parsing, typed AST accessors, the untyped `SyntaxNode` layer, text
/// ranges, navigation, preorder traversal, and `match_ast!`.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    // The leading newline and 8-space indentation are significant: the text
    // range asserted further down is a byte offset into this literal.
    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `parse` always yields a tree; problems are reported via `errors()`.
    let parse = SourceFile::parse(source_code);
    // Valid input: neither the parser nor validation report anything.
    assert!(parse.errors().is_empty());

    // `tree()` returns the typed root node.
    let file: SourceFile = parse.tree();

    // The input holds exactly one statement; any other kind is a test bug.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Typed accessors return `Option`, so each child is unwrapped explicitly.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // The return type is an `ast::Type` enum; `int` is its `PathType` variant.
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    // `syntax()` exposes the underlying untyped node; `to_string` gives its text.
    assert_eq!(type_path.syntax().to_string(), "int");

    let param_list: ast::ParamList = func.param_list().unwrap();
    // The only parameter is `p int8`.
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // The first option is `as 'select 1 + 1'` (it precedes `language sql`).
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Drop down to the untyped layer.
    let func_option_syntax = func_option.syntax();

    // The enum wrapper and its variant refer to the same underlying node.
    assert!(func_option_syntax == option.syntax());

    // `cast` goes the other way: untyped node -> typed AST node.
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // Every node carries a `SyntaxKind`.
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // 65..82 is the byte span of `as 'select 1 + 1'` within `source_code`.
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // `text()` returns a `SyntaxText`; `to_string` materialises it.
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // Navigation: parent, first child (tokens included), and siblings.
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    // Whitespace is kept in the tree as tokens, so the option's next sibling
    // is the whitespace before `language sql`.
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // Walk up the ancestors until one casts to the wanted typed node.
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    // Counts the node itself plus AS_KW, WHITESPACE, LITERAL and STRING
    // (the same five elements dumped below).
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5,
    );

    // Preorder traversal emits Enter/Leave events; dump each element's text
    // and kind, indented by depth.
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                // `{:indent$}` pads the single-space argument to `indent`
                // columns, so depth 0 still writes one space (removed by
                // `trim` below) and each level adds two.
                buf.write_fmt(format_args!(
                    "{:indent$}{:?} {:?}\n",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                ))
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    // Every Enter was matched by a Leave.
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // Two equivalent ways to find typed nodes: `filter_map(cast)` and
    // `match_ast!`. Both must visit the same nodes in the same order.
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
390
/// Collects `(table name, [(column name, column type)])` from two
/// `create table` statements and snapshots the result. Table constraints and
/// `like` clauses contribute nothing.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );

        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(table) => Some(table),
            _ => None,
        })
        .map(|table| {
            let table_name = table.path().unwrap().syntax().to_string();
            let columns: Vec<(String, String)> = table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => {
                        let col_name = column.name().unwrap();
                        let col_ty = column.ty().unwrap();
                        Some((col_name.syntax().to_string(), col_ty.syntax().to_string()))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
485
/// `select 1 is not null` parses its target as a binary expression whose
/// operator is a single `IS_NOT` node covering `is`, the whitespace, and `not`.
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let parse = SourceFile::parse("select 1 is not null;");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let first_stmt = file.stmts().next().unwrap();
    let ast::Stmt::Select(select) = first_stmt else {
        unreachable!()
    };

    let clause = select.select_clause().unwrap();
    let targets = clause.target_list().unwrap();
    let first_target = targets.targets().next().unwrap();
    let ast::Expr::BinExpr(bin_expr) = first_target.expr().unwrap() else {
        unreachable!()
    };

    let (lhs, op, rhs) = (bin_expr.lhs(), bin_expr.op(), bin_expr.rhs());

    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}