1pub mod ast;
28mod generated;
29mod parsing;
30mod ptr;
31pub mod quote;
32pub mod syntax_error;
33mod syntax_node;
34mod token_text;
35pub mod unescape;
36mod validation;
37
38#[cfg(test)]
39mod test;
40
41use std::{marker::PhantomData, sync::Arc};
42
43pub use squawk_parser::SyntaxKind;
44
45use ast::AstNode;
46pub use ptr::{AstPtr, SyntaxNodePtr};
47use rowan::GreenNode;
48use syntax_error::SyntaxError;
49pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
50pub use token_text::TokenText;
51
/// The outcome of parsing some text: a syntax tree plus the errors that were
/// reported while it was built.
///
/// The type parameter `T` names the AST type of the tree's root. It is only a
/// marker; no value of type `T` is stored in the struct.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    /// Green tree produced by the parser. Root `SyntaxNode`s are created from
    /// a clone of it on demand (see `syntax_node`).
    green: GreenNode,
    /// Errors reported at parse time. `Parse::new` stores `None` instead of an
    /// empty list when there were none.
    errors: Option<Arc<[SyntaxError]>>,
    /// Zero-sized marker recording the root type `T`.
    _ty: PhantomData<fn() -> T>,
}
63
64impl<T> Clone for Parse<T> {
65 fn clone(&self) -> Parse<T> {
66 Parse {
67 green: self.green.clone(),
68 errors: self.errors.clone(),
69 _ty: PhantomData,
70 }
71 }
72}
73
74impl<T> Parse<T> {
75 fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
76 Parse {
77 green,
78 errors: if errors.is_empty() {
79 None
80 } else {
81 Some(errors.into())
82 },
83 _ty: PhantomData,
84 }
85 }
86
87 pub fn syntax_node(&self) -> SyntaxNode {
88 SyntaxNode::new_root(self.green.clone())
89 }
90
91 pub fn errors(&self) -> Vec<SyntaxError> {
92 let mut errors = if let Some(e) = self.errors.as_deref() {
93 e.to_vec()
94 } else {
95 vec![]
96 };
97 validation::validate(&self.syntax_node(), &mut errors);
98 errors.sort_by_key(|error| error.range().start());
99 errors
100 }
101}
102
103impl<T: AstNode> Parse<T> {
104 pub fn to_syntax(self) -> Parse<SyntaxNode> {
106 Parse {
107 green: self.green,
108 errors: self.errors,
109 _ty: PhantomData,
110 }
111 }
112
113 pub fn tree(&self) -> T {
120 T::cast(self.syntax_node()).unwrap()
121 }
122
123 pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
125 match self.errors() {
126 errors if !errors.is_empty() => Err(errors),
127 _ => Ok(self.tree()),
128 }
129 }
130}
131
132impl Parse<SyntaxNode> {
133 pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
134 if N::cast(self.syntax_node()).is_some() {
135 Some(Parse {
136 green: self.green,
137 errors: self.errors,
138 _ty: PhantomData,
139 })
140 } else {
141 None
142 }
143 }
144}
145
146pub use crate::ast::SourceFile;
148
149impl SourceFile {
150 pub fn parse(text: &str) -> Parse<SourceFile> {
151 let (green, errors) = parsing::parse_text(text);
152 let root = SyntaxNode::new_root(green.clone());
153
154 assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
155 Parse::new(green, errors)
156 }
157}
158
/// Matches a syntax node against a list of AST types, taking the first arm
/// whose `cast` succeeds.
///
/// Each arm has the form `path::Type(pattern) => expr,` and a final
/// `_ => expr` arm is mandatory. The macro expands to an `if let ... else if
/// let ... else` chain that calls `path::Type::cast(node.clone())` for each
/// arm in order.
///
/// The scrutinee is either a bare identifier or a parenthesised expression.
/// Because it is substituted into every arm, the expression is evaluated (and
/// cloned) once per arm that is tried.
#[macro_export]
macro_rules! match_ast {
    // Bare identifier: wrap it in parentheses and defer to the general rule.
    (match $scrutinee:ident { $($arms:tt)* }) => {
        $crate::match_ast!(match ($scrutinee) { $($arms)* })
    };

    // General rule: one `if let` per typed arm, ending in the catch-all.
    (match ($scrutinee:expr) {
        $( $( $segment:ident )::+ ($binding:pat) => $body:expr, )*
        _ => $fallback:expr $(,)?
    }) => {{
        $(
            if let Some($binding) = $($segment::)+cast($scrutinee.clone()) {
                $body
            } else
        )*
        {
            $fallback
        }
    }};
}
185
/// A guided tour of the crate's API, written as a test so that it stays
/// compilable: parse some SQL, navigate the typed AST, drop down to the
/// untyped syntax tree, and walk that tree in a few different ways.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    // NOTE: the text offsets asserted further down (65..82) depend on the
    // exact whitespace of this literal: each line is indented by 8 spaces.
    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile::parse` returns a `Parse`, not a `Result`; problems are
    // reported through `errors()`. This input is expected to be error free.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // `tree` returns the typed root of the syntax tree.
    let file: SourceFile = parse.tree();

    // A `SourceFile` yields its statements; the input holds a single
    // `create function`, so any other statement kind is unreachable here.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Child accessors on AST nodes return `Option`s. Here we know the pieces
    // are present, so we simply unwrap down to the function's name.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // The return type is an enum of type kinds; here it is a path type.
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // Parameters hang off a parameter list; take the first (and only) one.
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // `FuncOption` is an enum; the first option of this function is the
    // `as '...'` clause carrying the function body as a string literal.
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Every AST node exposes the untyped `SyntaxNode` underneath it.
    let func_option_syntax = func_option.syntax();

    // The enum and the variant it wraps refer to the same syntax node.
    assert!(func_option_syntax == option.syntax());

    // Going the other way, from an untyped node back to a typed one, is done
    // with `cast`, which returns `None` when the node does not fit the type.
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // An untyped node knows its kind...
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // ...and the range it covers in the source text...
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // ...and the text itself.
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // The tree can be navigated through parents, children and siblings.
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // There are iterators over ancestors, siblings and descendants too.
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5, // AS_FUNC_OPTION, AS_KW, WHITESPACE, LITERAL, STRING (see the dump below)
    );

    // A preorder walk reports `Enter`/`Leave` events, which is enough to
    // render the subtree as an indented dump.
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                // `{:indent$}` pads the single space to `indent` columns; at
                // depth 0 that still emits one space, hence the `trim` below.
                buf.write_fmt(format_args!(
                    "{:indent$}{:?} {:?}\n",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                ))
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    // Every `Enter` must have been balanced by a `Leave`.
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // Two equivalent ways of collecting all nodes of one AST type.
    // First: filter the descendants through `cast`.
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    // Second: dispatch on each node with the `match_ast!` macro.
    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
393
/// Parses two `create table` statements and checks the table names together
/// with each column's name and type as seen through the typed AST.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );

        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Statements other than `create table` are skipped.
    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let mut columns = Vec::new();
            for arg in create_table.table_arg_list().unwrap().args() {
                // Only column definitions are recorded; the other kinds of
                // table argument are listed explicitly and ignored.
                let column = match arg {
                    ast::TableArg::Column(column) => column,
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => continue,
                };
                let column_name = column.name().unwrap();
                let column_type = column.ty().unwrap();
                columns.push((
                    column_name.syntax().to_string(),
                    column_type.syntax().to_string(),
                ));
            }
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
488
/// Checks that `1 is not null` parses as a binary expression and snapshots
/// its left operand, operator and right operand.
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let source_code = "select 1 is not null;";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // The input holds exactly one statement, and it is a `select`.
    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };

    // The first (and only) target of the select list is the expression under test.
    let select_clause = select.select_clause().unwrap();
    let first_target = select_clause
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let bin_expr = match first_target.expr().unwrap() {
        ast::Expr::BinExpr(bin_expr) => bin_expr,
        _ => unreachable!(),
    };

    let (left, operator, right) = (bin_expr.lhs(), bin_expr.op(), bin_expr.rhs());

    assert_debug_snapshot!(left, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(operator, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(right, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}