sql_parse/
statement.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{boxed::Box, vec::Vec};
14
15use crate::{
16    alter::{parse_alter, AlterTable},
17    create::{
18        parse_create, CreateFunction, CreateIndex, CreateTable, CreateTrigger, CreateTypeEnum,
19        CreateView,
20    },
21    delete::{parse_delete, Delete},
22    drop::{
23        parse_drop, DropDatabase, DropEvent, DropFunction, DropIndex, DropProcedure, DropServer,
24        DropTable, DropTrigger, DropView,
25    },
26    expression::{parse_expression, Expression},
27    insert_replace::{parse_insert_replace, InsertReplace},
28    keywords::Keyword,
29    lexer::Token,
30    parser::{ParseError, Parser},
31    rename::parse_rename_table,
32    select::{parse_select, OrderFlag, Select},
33    span::OptSpanned,
34    truncate::{parse_truncate_table, TruncateTable},
35    update::{parse_update, Update},
36    with_query::parse_with_query,
37    Identifier, RenameTable, Span, Spanned, WithQuery,
38};
39
40#[derive(Clone, Debug)]
41pub struct Set<'a> {
42    pub set_span: Span,
43    pub values: Vec<(Identifier<'a>, Expression<'a>)>,
44}
45
46impl<'a> Spanned for Set<'a> {
47    fn span(&self) -> Span {
48        self.set_span.join_span(&self.values)
49    }
50}
51
52fn parse_set<'a>(parser: &mut Parser<'a, '_>) -> Result<Set<'a>, ParseError> {
53    let set_span = parser.consume_keyword(Keyword::SET)?;
54    let mut values = Vec::new();
55    loop {
56        let name = parser.consume_plain_identifier()?;
57        parser.consume_token(Token::Eq)?;
58        let val = parse_expression(parser, false)?;
59        values.push((name, val));
60        if parser.skip_token(Token::Comma).is_none() {
61            break;
62        }
63    }
64    Ok(Set { set_span, values })
65}
66
67fn parse_statement_list_inner<'a>(
68    parser: &mut Parser<'a, '_>,
69    out: &mut Vec<Statement<'a>>,
70) -> Result<(), ParseError> {
71    loop {
72        while parser.skip_token(Token::SemiColon).is_some() {}
73        let stdin = match parse_statement(parser)? {
74            Some(v) => {
75                let stdin = v.reads_from_stdin();
76                out.push(v);
77                stdin
78            }
79            None => break,
80        };
81        if !matches!(parser.token, Token::SemiColon) {
82            break;
83        }
84        if stdin {
85            let (s, span) = parser.read_from_stdin_and_next();
86            out.push(Statement::Stdin(s, span));
87        } else {
88            parser.consume_token(Token::SemiColon)?;
89        }
90    }
91    Ok(())
92}
93
94fn parse_statement_list<'a>(
95    parser: &mut Parser<'a, '_>,
96    out: &mut Vec<Statement<'a>>,
97) -> Result<(), ParseError> {
98    let old_delimiter = core::mem::replace(&mut parser.delimiter, Token::SemiColon);
99    let r = parse_statement_list_inner(parser, out);
100    parser.delimiter = old_delimiter;
101    r
102}
103
104fn parse_begin(parser: &mut Parser<'_, '_>) -> Result<Span, ParseError> {
105    parser.consume_keyword(Keyword::BEGIN)
106}
107
108fn parse_end(parser: &mut Parser<'_, '_>) -> Result<Span, ParseError> {
109    parser.consume_keyword(Keyword::END)
110}
111
112fn parse_start<'a>(parser: &mut Parser<'a, '_>) -> Result<Statement<'a>, ParseError> {
113    Ok(Statement::StartTransaction(parser.consume_keywords(&[
114        Keyword::START,
115        Keyword::TRANSACTION,
116    ])?))
117}
118
119fn parse_commit(parser: &mut Parser<'_, '_>) -> Result<Span, ParseError> {
120    parser.consume_keyword(Keyword::COMMIT)
121}
122
123fn parse_block<'a>(parser: &mut Parser<'a, '_>) -> Result<Vec<Statement<'a>>, ParseError> {
124    parser.consume_keyword(Keyword::BEGIN)?;
125    let mut ans = Vec::new();
126    parser.recovered(
127        "'END' | 'EXCEPTION'",
128        &|e| {
129            matches!(
130                e,
131                Token::Ident(_, Keyword::END) | Token::Ident(_, Keyword::EXCEPTION)
132            )
133        },
134        |parser| parse_statement_list(parser, &mut ans),
135    )?;
136    if let Some(_exception_span) = parser.skip_keyword(Keyword::EXCEPTION) {
137        while let Some(_when_span) = parser.skip_keyword(Keyword::WHEN) {
138            parser.consume_plain_identifier()?;
139            parser.consume_keyword(Keyword::THEN)?;
140            parse_expression(parser, true)?;
141            parser.consume_token(Token::SemiColon)?;
142        }
143    }
144    parser.consume_keyword(Keyword::END)?;
145    Ok(ans)
146}
147
148/// Condition in if statement
149#[derive(Clone, Debug)]
150pub struct IfCondition<'a> {
151    /// Span of "ELSEIF" if specified
152    pub elseif_span: Option<Span>,
153    /// Expression that must be true for `then` to be executed
154    pub search_condition: Expression<'a>,
155    /// Span of "THEN"
156    pub then_span: Span,
157    /// List of statement to be executed if `search_condition` is true
158    pub then: Vec<Statement<'a>>,
159}
160
161impl<'a> Spanned for IfCondition<'a> {
162    fn span(&self) -> Span {
163        self.then_span
164            .join_span(&self.elseif_span)
165            .join_span(&self.search_condition)
166            .join_span(&self.then_span)
167            .join_span(&self.then)
168    }
169}
170
171/// If statement
172#[derive(Clone, Debug)]
173pub struct If<'a> {
174    /// Span of "IF"
175    pub if_span: Span,
176    // List of if a then v parts
177    pub conditions: Vec<IfCondition<'a>>,
178    /// Span of "ELSE" and else Statement if specified
179    pub else_: Option<(Span, Vec<Statement<'a>>)>,
180    /// Span of "ENDIF"
181    pub endif_span: Span,
182}
183
184impl<'a> Spanned for If<'a> {
185    fn span(&self) -> Span {
186        self.if_span
187            .join_span(&self.conditions)
188            .join_span(&self.else_)
189            .join_span(&self.endif_span)
190    }
191}
192
193fn parse_if<'a>(parser: &mut Parser<'a, '_>) -> Result<If<'a>, ParseError> {
194    let if_span = parser.consume_keyword(Keyword::IF)?;
195    let mut conditions = Vec::new();
196    let mut else_ = None;
197    parser.recovered(
198        "'END'",
199        &|e| matches!(e, Token::Ident(_, Keyword::END)),
200        |parser| {
201            let search_condition = parse_expression(parser, false)?;
202            let then_span = parser.consume_keyword(Keyword::THEN)?;
203            let mut then = Vec::new();
204            parse_statement_list(parser, &mut then)?;
205            conditions.push(IfCondition {
206                elseif_span: None,
207                search_condition,
208                then_span,
209                then,
210            });
211            while let Some(elseif_span) = parser.skip_keyword(Keyword::ELSEIF) {
212                let search_condition = parse_expression(parser, false)?;
213                let then_span = parser.consume_keyword(Keyword::THEN)?;
214                let mut then = Vec::new();
215                parse_statement_list(parser, &mut then)?;
216                conditions.push(IfCondition {
217                    elseif_span: Some(elseif_span),
218                    search_condition,
219                    then_span,
220                    then,
221                })
222            }
223            if let Some(else_span) = parser.skip_keyword(Keyword::ELSE) {
224                let mut o = Vec::new();
225                parse_statement_list(parser, &mut o)?;
226                else_ = Some((else_span, o));
227            }
228            Ok(())
229        },
230    )?;
231    let endif_span = parser.consume_keywords(&[Keyword::END, Keyword::IF])?;
232    Ok(If {
233        if_span,
234        conditions,
235        else_,
236        endif_span,
237    })
238}
239
240/// Return statement
241#[derive(Clone, Debug)]
242pub struct Return<'a> {
243    /// Span of "Return"
244    pub return_span: Span,
245    pub expr: Expression<'a>,
246}
247
248impl<'a> Spanned for Return<'a> {
249    fn span(&self) -> Span {
250        self.return_span.join_span(&self.expr)
251    }
252}
253
254fn parse_return<'a>(parser: &mut Parser<'a, '_>) -> Result<Return<'a>, ParseError> {
255    let return_span = parser.consume_keyword(Keyword::RETURN)?;
256    let expr = parse_expression(parser, false)?;
257    Ok(Return { return_span, expr })
258}
259
260#[derive(Clone, Debug)]
261pub enum SignalConditionInformationName {
262    ClassOrigin(Span),
263    SubclassOrigin(Span),
264    MessageText(Span),
265    MysqlErrno(Span),
266    ConstraintCatalog(Span),
267    ConstraintSchema(Span),
268    ConstraintName(Span),
269    CatalogName(Span),
270    SchemaName(Span),
271    TableName(Span),
272    ColumnName(Span),
273    CursorName(Span),
274}
275
276impl Spanned for SignalConditionInformationName {
277    fn span(&self) -> Span {
278        match self {
279            SignalConditionInformationName::ClassOrigin(span) => span.clone(),
280            SignalConditionInformationName::SubclassOrigin(span) => span.clone(),
281            SignalConditionInformationName::MessageText(span) => span.clone(),
282            SignalConditionInformationName::MysqlErrno(span) => span.clone(),
283            SignalConditionInformationName::ConstraintCatalog(span) => span.clone(),
284            SignalConditionInformationName::ConstraintSchema(span) => span.clone(),
285            SignalConditionInformationName::ConstraintName(span) => span.clone(),
286            SignalConditionInformationName::CatalogName(span) => span.clone(),
287            SignalConditionInformationName::SchemaName(span) => span.clone(),
288            SignalConditionInformationName::TableName(span) => span.clone(),
289            SignalConditionInformationName::ColumnName(span) => span.clone(),
290            SignalConditionInformationName::CursorName(span) => span.clone(),
291        }
292    }
293}
294
295/// Return statement
296#[derive(Clone, Debug)]
297pub struct Signal<'a> {
298    pub signal_span: Span,
299    pub sqlstate_span: Span,
300    pub value_span: Option<Span>,
301    pub sql_state: Expression<'a>,
302    pub set_span: Option<Span>,
303    pub sets: Vec<(SignalConditionInformationName, Span, Expression<'a>)>,
304}
305
306impl<'a> Spanned for Signal<'a> {
307    fn span(&self) -> Span {
308        self.signal_span
309            .join_span(&self.sqlstate_span)
310            .join_span(&self.value_span)
311            .join_span(&self.sql_state)
312            .join_span(&self.set_span)
313            .join_span(&self.sets)
314    }
315}
316
317fn parse_signal<'a>(parser: &mut Parser<'a, '_>) -> Result<Signal<'a>, ParseError> {
318    let signal_span = parser.consume_keyword(Keyword::SIGNAL)?;
319    let sqlstate_span = parser.consume_keyword(Keyword::SQLSTATE)?;
320    let value_span = parser.skip_keyword(Keyword::VALUE);
321    let sql_state = parse_expression(parser, false)?;
322    let mut sets = Vec::new();
323    let set_span = parser.skip_keyword(Keyword::SET);
324    if set_span.is_some() {
325        loop {
326            let v = match &parser.token {
327                Token::Ident(_, Keyword::CLASS_ORIGIN) => {
328                    SignalConditionInformationName::ClassOrigin(parser.consume())
329                }
330                Token::Ident(_, Keyword::SUBCLASS_ORIGIN) => {
331                    SignalConditionInformationName::SubclassOrigin(parser.consume())
332                }
333                Token::Ident(_, Keyword::MESSAGE_TEXT) => {
334                    SignalConditionInformationName::MessageText(parser.consume())
335                }
336                Token::Ident(_, Keyword::MYSQL_ERRNO) => {
337                    SignalConditionInformationName::MysqlErrno(parser.consume())
338                }
339                Token::Ident(_, Keyword::CONSTRAINT_CATALOG) => {
340                    SignalConditionInformationName::ConstraintCatalog(parser.consume())
341                }
342                Token::Ident(_, Keyword::CONSTRAINT_SCHEMA) => {
343                    SignalConditionInformationName::ConstraintSchema(parser.consume())
344                }
345                Token::Ident(_, Keyword::CONSTRAINT_NAME) => {
346                    SignalConditionInformationName::ConstraintName(parser.consume())
347                }
348                Token::Ident(_, Keyword::CATALOG_NAME) => {
349                    SignalConditionInformationName::CatalogName(parser.consume())
350                }
351                Token::Ident(_, Keyword::SCHEMA_NAME) => {
352                    SignalConditionInformationName::SchemaName(parser.consume())
353                }
354                Token::Ident(_, Keyword::TABLE_NAME) => {
355                    SignalConditionInformationName::TableName(parser.consume())
356                }
357                Token::Ident(_, Keyword::COLUMN_NAME) => {
358                    SignalConditionInformationName::ColumnName(parser.consume())
359                }
360                Token::Ident(_, Keyword::CURSOR_NAME) => {
361                    SignalConditionInformationName::CursorName(parser.consume())
362                }
363                _ => parser.expected_failure("Condition information item name")?,
364            };
365            let eq_span = parser.consume_token(Token::Eq)?;
366            let value = parse_expression(parser, false)?;
367            sets.push((v, eq_span, value));
368            if parser.skip_token(Token::Comma).is_none() {
369                break;
370            }
371        }
372    }
373    Ok(Signal {
374        signal_span,
375        sqlstate_span,
376        value_span,
377        sql_state,
378        set_span,
379        sets,
380    })
381}
382
/// SQL statement
///
/// Variants that simply wrap a struct from another module (`CreateTable`,
/// `Delete`, `AlterTable`, ...) are documented on that struct.
// `clippy::large_enum_variant` is allowed because payload structs are stored
// inline (unboxed). NOTE(review): confirm the size cost was accepted deliberately.
#[allow(clippy::large_enum_variant)]
#[derive(Clone, Debug)]
pub enum Statement<'a> {
    CreateIndex(CreateIndex<'a>),
    CreateTable(CreateTable<'a>),
    CreateView(CreateView<'a>),
    CreateTrigger(CreateTrigger<'a>),
    CreateFunction(CreateFunction<'a>),
    /// A single `SELECT` (a compound query with no `UNION`)
    Select(Select<'a>),
    Delete(Delete<'a>),
    /// `INSERT` or `REPLACE`
    InsertReplace(InsertReplace<'a>),
    Update(Update<'a>),
    DropIndex(DropIndex<'a>),
    DropTable(DropTable<'a>),
    DropFunction(DropFunction<'a>),
    DropProcedure(DropProcedure<'a>),
    DropEvent(DropEvent<'a>),
    DropDatabase(DropDatabase<'a>),
    DropServer(DropServer<'a>),
    DropTrigger(DropTrigger<'a>),
    DropView(DropView<'a>),
    /// `SET name = expr, ...`
    Set(Set<'a>),
    /// `SIGNAL SQLSTATE ...`
    Signal(Signal<'a>),
    AlterTable(AlterTable<'a>),
    /// Compound `BEGIN ... END` block (parsed by `parse_block` when compound
    /// statements are permitted); holds only the inner statements.
    // TODO we should include begin and end
    Block(Vec<Statement<'a>>),
    /// Bare `BEGIN` (compound statements not permitted)
    Begin(Span),
    /// Bare `END` (compound statements not permitted)
    End(Span),
    /// `COMMIT`
    Commit(Span),
    /// `START TRANSACTION`
    StartTransaction(Span),
    /// `IF ... END IF`
    If(If<'a>),
    /// Invalid statement produced after recovering from parse error
    Invalid(Span),
    /// Compound query combining operands with `UNION`
    Union(Union<'a>),
    /// Procedural `CASE` statement
    Case(CaseStatement<'a>),
    /// `COPY table (columns) FROM STDIN`; followed by a `Stdin` statement
    Copy(Copy<'a>),
    /// Raw inline data read after a statement whose `reads_from_stdin()` is
    /// true, together with its span
    Stdin(&'a str, Span),
    CreateTypeEnum(CreateTypeEnum<'a>),
    /// `DO $$ BEGIN ... END $$`; holds only the block's inner statements
    Do(Vec<Statement<'a>>),
    TruncateTable(TruncateTable<'a>),
    RenameTable(RenameTable<'a>),
    WithQuery(WithQuery<'a>),
    /// `RETURN expr`
    Return(Return<'a>),
}
427
428impl<'a> Spanned for Statement<'a> {
429    fn span(&self) -> Span {
430        match &self {
431            Statement::CreateIndex(v) => v.span(),
432            Statement::CreateTable(v) => v.span(),
433            Statement::CreateView(v) => v.span(),
434            Statement::CreateTrigger(v) => v.span(),
435            Statement::CreateFunction(v) => v.span(),
436            Statement::Select(v) => v.span(),
437            Statement::Delete(v) => v.span(),
438            Statement::InsertReplace(v) => v.span(),
439            Statement::Update(v) => v.span(),
440            Statement::DropIndex(v) => v.span(),
441            Statement::DropTable(v) => v.span(),
442            Statement::DropFunction(v) => v.span(),
443            Statement::DropProcedure(v) => v.span(),
444            Statement::DropEvent(v) => v.span(),
445            Statement::DropDatabase(v) => v.span(),
446            Statement::DropServer(v) => v.span(),
447            Statement::DropTrigger(v) => v.span(),
448            Statement::DropView(v) => v.span(),
449            Statement::Set(v) => v.span(),
450            Statement::AlterTable(v) => v.span(),
451            Statement::Block(v) => v.opt_span().expect("Span of block"),
452            Statement::If(v) => v.span(),
453            Statement::Invalid(v) => v.span(),
454            Statement::Union(v) => v.span(),
455            Statement::Case(v) => v.span(),
456            Statement::Copy(v) => v.span(),
457            Statement::Stdin(_, s) => s.clone(),
458            Statement::Begin(s) => s.clone(),
459            Statement::End(s) => s.clone(),
460            Statement::Commit(s) => s.clone(),
461            Statement::StartTransaction(s) => s.clone(),
462            Statement::CreateTypeEnum(v) => v.span(),
463            Statement::Do(v) => v.opt_span().expect("Span of block"),
464            Statement::TruncateTable(v) => v.span(),
465            Statement::RenameTable(v) => v.span(),
466            Statement::WithQuery(v) => v.span(),
467            Statement::Return(v) => v.span(),
468            Statement::Signal(v) => v.span(),
469        }
470    }
471}
472
473impl Statement<'_> {
474    fn reads_from_stdin(&self) -> bool {
475        match self {
476            Statement::Copy(v) => v.reads_from_stdin(),
477            _ => false,
478        }
479    }
480}
481
482pub(crate) fn parse_statement<'a>(
483    parser: &mut Parser<'a, '_>,
484) -> Result<Option<Statement<'a>>, ParseError> {
485    Ok(match &parser.token {
486        Token::Ident(_, Keyword::CREATE) => Some(parse_create(parser)?),
487        Token::Ident(_, Keyword::DROP) => Some(parse_drop(parser)?),
488        Token::Ident(_, Keyword::SELECT) | Token::LParen => Some(parse_compound_query(parser)?),
489        Token::Ident(_, Keyword::DELETE) => Some(Statement::Delete(parse_delete(parser)?)),
490        Token::Ident(_, Keyword::INSERT | Keyword::REPLACE) => {
491            Some(Statement::InsertReplace(parse_insert_replace(parser)?))
492        }
493        Token::Ident(_, Keyword::UPDATE) => Some(Statement::Update(parse_update(parser)?)),
494        Token::Ident(_, Keyword::SET) => Some(Statement::Set(parse_set(parser)?)),
495        Token::Ident(_, Keyword::SIGNAL) => Some(Statement::Signal(parse_signal(parser)?)),
496        Token::Ident(_, Keyword::BEGIN) => Some(if parser.permit_compound_statements {
497            Statement::Block(parse_block(parser)?)
498        } else {
499            Statement::Begin(parse_begin(parser)?)
500        }),
501        Token::Ident(_, Keyword::END) if !parser.permit_compound_statements => {
502            Some(Statement::End(parse_end(parser)?))
503        }
504        Token::Ident(_, Keyword::START) => Some(parse_start(parser)?),
505        Token::Ident(_, Keyword::COMMIT) => Some(Statement::Commit(parse_commit(parser)?)),
506        Token::Ident(_, Keyword::IF) => Some(Statement::If(parse_if(parser)?)),
507        Token::Ident(_, Keyword::RETURN) => Some(Statement::Return(parse_return(parser)?)),
508        Token::Ident(_, Keyword::ALTER) => Some(parse_alter(parser)?),
509        Token::Ident(_, Keyword::CASE) => Some(Statement::Case(parse_case_statement(parser)?)),
510        Token::Ident(_, Keyword::COPY) => Some(Statement::Copy(parse_copy_statement(parser)?)),
511        Token::Ident(_, Keyword::DO) => Some(parse_do(parser)?),
512        Token::Ident(_, Keyword::TRUNCATE) => {
513            Some(Statement::TruncateTable(parse_truncate_table(parser)?))
514        }
515        Token::Ident(_, Keyword::RENAME) => {
516            Some(Statement::RenameTable(parse_rename_table(parser)?))
517        }
518        Token::Ident(_, Keyword::WITH) => Some(Statement::WithQuery(parse_with_query(parser)?)),
519        _ => None,
520    })
521}
522
523pub(crate) fn parse_do<'a>(parser: &mut Parser<'a, '_>) -> Result<Statement<'a>, ParseError> {
524    parser.consume_keyword(Keyword::DO)?;
525    parser.consume_token(Token::DoubleDollar)?;
526    let block = parse_block(parser)?;
527    parser.consume_token(Token::DoubleDollar)?;
528    Ok(Statement::Do(block))
529}
530
531/// When part of case statement
532#[derive(Clone, Debug)]
533pub struct WhenStatement<'a> {
534    /// Span of "WHEN"
535    pub when_span: Span,
536    /// Expression who's match yields execution `then`
537    pub when: Expression<'a>,
538    /// Span of "THEN"
539    pub then_span: Span,
540    /// Statements to execute if `when` matches
541    pub then: Vec<Statement<'a>>,
542}
543
544impl<'a> Spanned for WhenStatement<'a> {
545    fn span(&self) -> Span {
546        self.when_span
547            .join_span(&self.when)
548            .join_span(&self.then_span)
549            .join_span(&self.then)
550    }
551}
552
553/// Case statement
554#[derive(Clone, Debug)]
555pub struct CaseStatement<'a> {
556    /// Span of "CASE"
557    pub case_span: Span,
558    /// Value to match against
559    pub value: Option<Box<Expression<'a>>>,
560    /// List of whens
561    pub whens: Vec<WhenStatement<'a>>,
562    /// Span of "ELSE" and statement to execute if specified
563    pub else_: Option<(Span, Vec<Statement<'a>>)>,
564    /// Span of "END"
565    pub end_span: Span,
566}
567
568impl<'a> Spanned for CaseStatement<'a> {
569    fn span(&self) -> Span {
570        self.case_span
571            .join_span(&self.value)
572            .join_span(&self.whens)
573            .join_span(&self.else_)
574            .join_span(&self.end_span)
575    }
576}
577
578pub(crate) fn parse_case_statement<'a>(
579    parser: &mut Parser<'a, '_>,
580) -> Result<CaseStatement<'a>, ParseError> {
581    let case_span = parser.consume_keyword(Keyword::CASE)?;
582    let value = if !matches!(parser.token, Token::Ident(_, Keyword::WHEN)) {
583        Some(Box::new(parse_expression(parser, false)?))
584    } else {
585        None
586    };
587
588    let mut whens = Vec::new();
589    let mut else_ = None;
590    parser.recovered(
591        "'END'",
592        &|t| matches!(t, Token::Ident(_, Keyword::END)),
593        |parser| {
594            loop {
595                let when_span = parser.consume_keyword(Keyword::WHEN)?;
596                let when = parse_expression(parser, false)?;
597                let then_span = parser.consume_keyword(Keyword::THEN)?;
598                let mut then = Vec::new();
599                parse_statement_list(parser, &mut then)?;
600                whens.push(WhenStatement {
601                    when_span,
602                    when,
603                    then_span,
604                    then,
605                });
606                if !matches!(parser.token, Token::Ident(_, Keyword::WHEN)) {
607                    break;
608                }
609            }
610            if let Some(span) = parser.skip_keyword(Keyword::ELSE) {
611                let mut e = Vec::new();
612                parse_statement_list(parser, &mut e)?;
613                else_ = Some((span, e))
614            };
615            Ok(())
616        },
617    )?;
618    let end_span = parser.consume_keyword(Keyword::END)?;
619    Ok(CaseStatement {
620        case_span,
621        value,
622        whens,
623        else_,
624        end_span,
625    })
626}
627
628pub(crate) fn parse_copy_statement<'a>(
629    parser: &mut Parser<'a, '_>,
630) -> Result<Copy<'a>, ParseError> {
631    let copy_span = parser.consume_keyword(Keyword::COPY)?;
632    let table = parser.consume_plain_identifier()?;
633    parser.consume_token(Token::LParen)?;
634    let mut columns = Vec::new();
635    if !matches!(parser.token, Token::RParen) {
636        loop {
637            parser.recovered(
638                "')' or ','",
639                &|t| matches!(t, Token::RParen | Token::Comma),
640                |parser| {
641                    columns.push(parser.consume_plain_identifier()?);
642                    Ok(())
643                },
644            )?;
645            if matches!(parser.token, Token::RParen) {
646                break;
647            }
648            parser.consume_token(Token::Comma)?;
649        }
650    }
651    parser.consume_token(Token::RParen)?;
652    let from_span = parser.consume_keyword(Keyword::FROM)?;
653    let stdin_span = parser.consume_keyword(Keyword::STDIN)?;
654
655    Ok(Copy {
656        copy_span,
657        table,
658        columns,
659        from_span,
660        stdin_span,
661    })
662}
663
664pub(crate) fn parse_compound_query_bottom<'a>(
665    parser: &mut Parser<'a, '_>,
666) -> Result<Statement<'a>, ParseError> {
667    match &parser.token {
668        Token::LParen => {
669            let lp = parser.consume_token(Token::LParen)?;
670            let s = parser.recovered("')'", &|t| t == &Token::RParen, |parser| {
671                Ok(Some(parse_compound_query(parser)?))
672            })?;
673            parser.consume_token(Token::RParen)?;
674            Ok(s.unwrap_or(Statement::Invalid(lp)))
675        }
676        Token::Ident(_, Keyword::SELECT) => Ok(Statement::Select(parse_select(parser)?)),
677        _ => parser.expected_failure("'SELECET' or '('")?,
678    }
679}
680
681/// Type of union to perform
682#[derive(Clone, Debug)]
683pub enum UnionType {
684    All(Span),
685    Distinct(Span),
686    Default,
687}
688
689impl OptSpanned for UnionType {
690    fn opt_span(&self) -> Option<Span> {
691        match &self {
692            UnionType::All(v) => v.opt_span(),
693            UnionType::Distinct(v) => v.opt_span(),
694            UnionType::Default => None,
695        }
696    }
697}
698
699/// Right hand side of a union expression
700#[derive(Clone, Debug)]
701pub struct UnionWith<'a> {
702    /// Span of "UNION"
703    pub union_span: Span,
704    /// Type of union to perform
705    pub union_type: UnionType,
706    /// Statement to union
707    pub union_statement: Box<Statement<'a>>,
708}
709
710impl<'a> Spanned for UnionWith<'a> {
711    fn span(&self) -> Span {
712        self.union_span
713            .join_span(&self.union_type)
714            .join_span(&self.union_statement)
715    }
716}
717
718/// Union statement
719#[derive(Clone, Debug)]
720pub struct Union<'a> {
721    /// Left side of union
722    pub left: Box<Statement<'a>>,
723    /// List of things to union
724    pub with: Vec<UnionWith<'a>>,
725    /// Span of "ORDER BY", and list of ordering expressions and directions if specified
726    pub order_by: Option<(Span, Vec<(Expression<'a>, OrderFlag)>)>,
727    /// Span of "LIMIT", offset and count expressions if specified
728    pub limit: Option<(Span, Option<Expression<'a>>, Expression<'a>)>,
729}
730
731impl<'a> Spanned for Union<'a> {
732    fn span(&self) -> Span {
733        self.left
734            .join_span(&self.with)
735            .join_span(&self.order_by)
736            .join_span(&self.limit)
737    }
738}
739
740#[derive(Clone, Debug)]
741pub struct Copy<'a> {
742    pub copy_span: Span,
743    pub table: Identifier<'a>,
744    pub columns: Vec<Identifier<'a>>,
745    pub from_span: Span,
746    pub stdin_span: Span,
747}
748
749impl<'a> Spanned for Copy<'a> {
750    fn span(&self) -> Span {
751        self.copy_span
752            .join_span(&self.table)
753            .join_span(&self.columns)
754            .join_span(&self.from_span)
755            .join_span(&self.stdin_span)
756    }
757}
758
759impl<'a> Copy<'a> {
760    fn reads_from_stdin(&self) -> bool {
761        // There are COPY statements that don't read from STDIN,
762        // but we don't support them in this parser - we only support FROM STDIN.
763        true
764    }
765}
766
767pub(crate) fn parse_compound_query<'a>(
768    parser: &mut Parser<'a, '_>,
769) -> Result<Statement<'a>, ParseError> {
770    let q = parse_compound_query_bottom(parser)?;
771    if !matches!(parser.token, Token::Ident(_, Keyword::UNION)) {
772        return Ok(q);
773    };
774    let mut with = Vec::new();
775    loop {
776        let union_span = parser.consume_keyword(Keyword::UNION)?;
777        let union_type = match &parser.token {
778            Token::Ident(_, Keyword::ALL) => UnionType::All(parser.consume_keyword(Keyword::ALL)?),
779            Token::Ident(_, Keyword::DISTINCT) => {
780                UnionType::Distinct(parser.consume_keyword(Keyword::DISTINCT)?)
781            }
782            _ => UnionType::Default,
783        };
784        let union_statement = Box::new(parse_compound_query_bottom(parser)?);
785        with.push(UnionWith {
786            union_span,
787            union_type,
788            union_statement,
789        });
790        if !matches!(parser.token, Token::Ident(_, Keyword::UNION)) {
791            break;
792        }
793    }
794
795    let order_by = if let Some(span) = parser.skip_keyword(Keyword::ORDER) {
796        let span = parser.consume_keyword(Keyword::BY)?.join_span(&span);
797        let mut order = Vec::new();
798        loop {
799            let e = parse_expression(parser, false)?;
800            let f = match &parser.token {
801                Token::Ident(_, Keyword::ASC) => OrderFlag::Asc(parser.consume()),
802                Token::Ident(_, Keyword::DESC) => OrderFlag::Desc(parser.consume()),
803                _ => OrderFlag::None,
804            };
805            order.push((e, f));
806            if parser.skip_token(Token::Comma).is_none() {
807                break;
808            }
809        }
810        Some((span, order))
811    } else {
812        None
813    };
814
815    let limit = if let Some(span) = parser.skip_keyword(Keyword::LIMIT) {
816        let n = parse_expression(parser, true)?;
817        match parser.token {
818            Token::Comma => {
819                parser.consume();
820                Some((span, Some(n), parse_expression(parser, true)?))
821            }
822            Token::Ident(_, Keyword::OFFSET) => {
823                parser.consume();
824                Some((span, Some(parse_expression(parser, true)?), n))
825            }
826            _ => Some((span, None, n)),
827        }
828    } else {
829        None
830    };
831
832    Ok(Statement::Union(Union {
833        left: Box::new(q),
834        with,
835        order_by,
836        limit,
837    }))
838}
839
/// Parse every statement up to `Token::Eof` and return those that parsed
/// successfully.
///
/// Statements are separated by `parser.delimiter`; runs of delimiters are
/// skipped. A `DELIMITER <token>` directive replaces the delimiter for the
/// rest of the input, with a warning unless the new token is `$$` or `;`.
/// No error is returned from here:
/// - a statement whose parse returns `Err` is dropped. The error value is
///   discarded here; presumably the parser has already recorded a diagnostic
///   (TODO confirm).
/// - a successful statement not followed by the delimiter is kept, an
///   "expected <delimiter>" error is reported through the parser, and tokens
///   are skipped up to the next delimiter.
///
/// After a statement whose `reads_from_stdin()` is true, the inline data is
/// read with `read_from_stdin_and_next` and pushed as `Statement::Stdin`,
/// instead of consuming the delimiter token.
pub(crate) fn parse_statements<'a>(parser: &mut Parser<'a, '_>) -> Vec<Statement<'a>> {
    let mut ans = Vec::new();
    loop {
        // Skip any run of delimiters; stop at end of input.
        loop {
            match &parser.token {
                t if t == &parser.delimiter => {
                    parser.consume();
                }
                Token::Eof => return ans,
                _ => break,
            }
        }

        // `DELIMITER x`: the token after the keyword becomes the new delimiter
        // and is consumed without being treated as a separator.
        if parser.skip_keyword(Keyword::DELIMITER).is_some() {
            let t = parser.token.clone();

            if !matches!(t, Token::DoubleDollar | Token::SemiColon) {
                parser.warn("Unknown delimiter", &parser.span.span());
            }
            parser.delimiter = t;
            parser.next();
            continue;
        }

        // A token that does not start any statement is reported as an expected-statement failure.
        let stmt = match parse_statement(parser) {
            Ok(Some(v)) => Ok(v),
            Ok(None) => parser.expected_failure("Statement"),
            Err(e) => Err(e),
        };
        let err = stmt.is_err();
        let mut from_stdin = false;
        if let Ok(stmt) = stmt {
            from_stdin = stmt.reads_from_stdin();
            ans.push(stmt);
        }

        if parser.token != parser.delimiter {
            // Only report a missing delimiter if the statement itself parsed;
            // a failed statement has its own error.
            if !err {
                parser.expected_error(parser.delimiter.name());
            }
            // We use a custom recovery here as ; is not allowed in sub expressions, it always terminates outer most statements
            loop {
                parser.next();
                match &parser.token {
                    t if t == &parser.delimiter => break,
                    Token::Eof => return ans,
                    _ => (),
                }
            }
        }
        // At this point the current token is the delimiter.
        if from_stdin {
            let (s, span) = parser.read_from_stdin_and_next();
            ans.push(Statement::Stdin(s, span));
        } else {
            parser
                .consume_token(parser.delimiter.clone())
                .expect("Delimiter");
        }
    }
}