sql_parse/
statement.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use alloc::{boxed::Box, vec::Vec};
14
15use crate::{
16    alter::{parse_alter, AlterTable},
17    create::{
18        parse_create, CreateFunction, CreateIndex, CreateTable, CreateTrigger, CreateTypeEnum,
19        CreateView,
20    },
21    delete::{parse_delete, Delete},
22    drop::{
23        parse_drop, DropDatabase, DropEvent, DropFunction, DropIndex, DropProcedure, DropServer,
24        DropTable, DropTrigger, DropView,
25    },
26    expression::{parse_expression, Expression},
27    insert_replace::{parse_insert_replace, InsertReplace},
28    keywords::Keyword,
29    lexer::Token,
30    parser::{ParseError, Parser},
31    rename::parse_rename_table,
32    select::{parse_select, OrderFlag, Select},
33    span::OptSpanned,
34    truncate::{parse_truncate_table, TruncateTable},
35    update::{parse_update, Update},
36    with_query::parse_with_query,
37    Identifier, RenameTable, Span, Spanned, WithQuery,
38};
39
40#[derive(Clone, Debug)]
41pub struct Set<'a> {
42    pub set_span: Span,
43    pub values: Vec<(Identifier<'a>, Expression<'a>)>,
44}
45
46impl<'a> Spanned for Set<'a> {
47    fn span(&self) -> Span {
48        self.set_span.join_span(&self.values)
49    }
50}
51
52fn parse_set<'a>(parser: &mut Parser<'a, '_>) -> Result<Set<'a>, ParseError> {
53    let set_span = parser.consume_keyword(Keyword::SET)?;
54    let mut values = Vec::new();
55    loop {
56        let name = parser.consume_plain_identifier()?;
57        parser.consume_token(Token::Eq)?;
58        let val = parse_expression(parser, false)?;
59        values.push((name, val));
60        if parser.skip_token(Token::Comma).is_none() {
61            break;
62        }
63    }
64    Ok(Set { set_span, values })
65}
66
67fn parse_statement_list_inner<'a>(
68    parser: &mut Parser<'a, '_>,
69    out: &mut Vec<Statement<'a>>,
70) -> Result<(), ParseError> {
71    loop {
72        while parser.skip_token(Token::SemiColon).is_some() {}
73        let stdin = match parse_statement(parser)? {
74            Some(v) => {
75                let stdin = v.reads_from_stdin();
76                out.push(v);
77                stdin
78            }
79            None => break,
80        };
81        if !matches!(parser.token, Token::SemiColon) {
82            break;
83        }
84        if stdin {
85            let (s, span) = parser.read_from_stdin_and_next();
86            out.push(Statement::Stdin(s, span));
87        } else {
88            parser.consume_token(Token::SemiColon)?;
89        }
90    }
91    Ok(())
92}
93
94fn parse_statement_list<'a>(
95    parser: &mut Parser<'a, '_>,
96    out: &mut Vec<Statement<'a>>,
97) -> Result<(), ParseError> {
98    let old_delimiter = core::mem::replace(&mut parser.delimiter, Token::SemiColon);
99    let r = parse_statement_list_inner(parser, out);
100    parser.delimiter = old_delimiter;
101    r
102}
103
104fn parse_begin(parser: &mut Parser<'_, '_>) -> Result<Span, ParseError> {
105    parser.consume_keyword(Keyword::BEGIN)
106}
107
108fn parse_end(parser: &mut Parser<'_, '_>) -> Result<Span, ParseError> {
109    parser.consume_keyword(Keyword::END)
110}
111
112fn parse_start<'a>(parser: &mut Parser<'a, '_>) -> Result<Statement<'a>, ParseError> {
113    Ok(Statement::StartTransaction(parser.consume_keywords(&[
114        Keyword::START,
115        Keyword::TRANSACTION,
116    ])?))
117}
118
119fn parse_commit(parser: &mut Parser<'_, '_>) -> Result<Span, ParseError> {
120    parser.consume_keyword(Keyword::COMMIT)
121}
122
/// Parses `BEGIN <statements> [EXCEPTION WHEN <name> THEN <expr>; ...] END` and
/// returns the statements between `BEGIN` and `EXCEPTION`/`END`.
///
/// Exception handlers are consumed but discarded: they are not part of the
/// returned value. The returned list may be empty (e.g. `BEGIN END`).
fn parse_block<'a>(parser: &mut Parser<'a, '_>) -> Result<Vec<Statement<'a>>, ParseError> {
    parser.consume_keyword(Keyword::BEGIN)?;
    let mut ans = Vec::new();
    // Error recovery for the body is anchored on `END` or `EXCEPTION`
    // (see `Parser::recovered` for the exact semantics).
    parser.recovered(
        "'END' | 'EXCEPTION'",
        &|e| {
            matches!(
                e,
                Token::Ident(_, Keyword::END) | Token::Ident(_, Keyword::EXCEPTION)
            )
        },
        |parser| parse_statement_list(parser, &mut ans),
    )?;
    if let Some(_exception_span) = parser.skip_keyword(Keyword::EXCEPTION) {
        // NOTE(review): each handler body is parsed as a single expression followed
        // by `;`, not as a statement list, so handlers containing statements (e.g.
        // PL/pgSQL `RAISE ...`) may not parse. Confirm this is intended.
        while let Some(_when_span) = parser.skip_keyword(Keyword::WHEN) {
            parser.consume_plain_identifier()?;
            parser.consume_keyword(Keyword::THEN)?;
            parse_expression(parser, true)?;
            parser.consume_token(Token::SemiColon)?;
        }
    }
    parser.consume_keyword(Keyword::END)?;
    Ok(ans)
}
147
148/// Condition in if statement
149#[derive(Clone, Debug)]
150pub struct IfCondition<'a> {
151    /// Span of "ELSEIF" if specified
152    pub elseif_span: Option<Span>,
153    /// Expression that must be true for `then` to be executed
154    pub search_condition: Expression<'a>,
155    /// Span of "THEN"
156    pub then_span: Span,
157    /// List of statement to be executed if `search_condition` is true
158    pub then: Vec<Statement<'a>>,
159}
160
161impl<'a> Spanned for IfCondition<'a> {
162    fn span(&self) -> Span {
163        self.then_span
164            .join_span(&self.elseif_span)
165            .join_span(&self.search_condition)
166            .join_span(&self.then_span)
167            .join_span(&self.then)
168    }
169}
170
171/// If statement
172#[derive(Clone, Debug)]
173pub struct If<'a> {
174    /// Span of "IF"
175    pub if_span: Span,
176    // List of if a then v parts
177    pub conditions: Vec<IfCondition<'a>>,
178    /// Span of "ELSE" and else Statement if specified
179    pub else_: Option<(Span, Vec<Statement<'a>>)>,
180    /// Span of "ENDIF"
181    pub endif_span: Span,
182}
183
184impl<'a> Spanned for If<'a> {
185    fn span(&self) -> Span {
186        self.if_span
187            .join_span(&self.conditions)
188            .join_span(&self.else_)
189            .join_span(&self.endif_span)
190    }
191}
192
193fn parse_if<'a>(parser: &mut Parser<'a, '_>) -> Result<If<'a>, ParseError> {
194    let if_span = parser.consume_keyword(Keyword::IF)?;
195    let mut conditions = Vec::new();
196    let mut else_ = None;
197    parser.recovered(
198        "'END'",
199        &|e| matches!(e, Token::Ident(_, Keyword::END)),
200        |parser| {
201            let search_condition = parse_expression(parser, false)?;
202            let then_span = parser.consume_keyword(Keyword::THEN)?;
203            let mut then = Vec::new();
204            parse_statement_list(parser, &mut then)?;
205            conditions.push(IfCondition {
206                elseif_span: None,
207                search_condition,
208                then_span,
209                then,
210            });
211            while let Some(elseif_span) = parser.skip_keyword(Keyword::ELSEIF) {
212                let search_condition = parse_expression(parser, false)?;
213                let then_span = parser.consume_keyword(Keyword::THEN)?;
214                let mut then = Vec::new();
215                parse_statement_list(parser, &mut then)?;
216                conditions.push(IfCondition {
217                    elseif_span: Some(elseif_span),
218                    search_condition,
219                    then_span,
220                    then,
221                })
222            }
223            if let Some(else_span) = parser.skip_keyword(Keyword::ELSE) {
224                let mut o = Vec::new();
225                parse_statement_list(parser, &mut o)?;
226                else_ = Some((else_span, o));
227            }
228            Ok(())
229        },
230    )?;
231    let endif_span = parser.consume_keywords(&[Keyword::END, Keyword::IF])?;
232    Ok(If {
233        if_span,
234        conditions,
235        else_,
236        endif_span,
237    })
238}
239
240/// Return statement
241#[derive(Clone, Debug)]
242pub struct Return<'a> {
243    /// Span of "Return"
244    pub return_span: Span,
245    pub expr: Expression<'a>,
246}
247
248impl<'a> Spanned for Return<'a> {
249    fn span(&self) -> Span {
250        self.return_span.join_span(&self.expr)
251    }
252}
253
254fn parse_return<'a>(parser: &mut Parser<'a, '_>) -> Result<Return<'a>, ParseError> {
255    let return_span = parser.consume_keyword(Keyword::RETURN)?;
256    let expr = parse_expression(parser, false)?;
257    Ok(Return { return_span, expr })
258}
259
260#[derive(Clone, Debug)]
261pub enum SignalConditionInformationName {
262    ClassOrigin(Span),
263    SubclassOrigin(Span),
264    MessageText(Span),
265    MysqlErrno(Span),
266    ConstraintCatalog(Span),
267    ConstraintSchema(Span),
268    ConstraintName(Span),
269    CatalogName(Span),
270    SchemaName(Span),
271    TableName(Span),
272    ColumnName(Span),
273    CursorName(Span),
274}
275
276impl Spanned for SignalConditionInformationName {
277    fn span(&self) -> Span {
278        match self {
279            SignalConditionInformationName::ClassOrigin(span) => span.clone(),
280            SignalConditionInformationName::SubclassOrigin(span) => span.clone(),
281            SignalConditionInformationName::MessageText(span) => span.clone(),
282            SignalConditionInformationName::MysqlErrno(span) => span.clone(),
283            SignalConditionInformationName::ConstraintCatalog(span) => span.clone(),
284            SignalConditionInformationName::ConstraintSchema(span) => span.clone(),
285            SignalConditionInformationName::ConstraintName(span) => span.clone(),
286            SignalConditionInformationName::CatalogName(span) => span.clone(),
287            SignalConditionInformationName::SchemaName(span) => span.clone(),
288            SignalConditionInformationName::TableName(span) => span.clone(),
289            SignalConditionInformationName::ColumnName(span) => span.clone(),
290            SignalConditionInformationName::CursorName(span) => span.clone(),
291        }
292    }
293}
294
295/// Return statement
296#[derive(Clone, Debug)]
297pub struct Signal<'a> {
298    pub signal_span: Span,
299    pub sqlstate_span: Span,
300    pub value_span: Option<Span>,
301    pub sql_state: Expression<'a>,
302    pub set_span: Option<Span>,
303    pub sets: Vec<(SignalConditionInformationName, Span, Expression<'a>)>,
304}
305
306impl<'a> Spanned for Signal<'a> {
307    fn span(&self) -> Span {
308        self.signal_span
309            .join_span(&self.sqlstate_span)
310            .join_span(&self.value_span)
311            .join_span(&self.sql_state)
312            .join_span(&self.set_span)
313            .join_span(&self.sets)
314    }
315}
316
317fn parse_signal<'a>(parser: &mut Parser<'a, '_>) -> Result<Signal<'a>, ParseError> {
318    let signal_span = parser.consume_keyword(Keyword::SIGNAL)?;
319    let sqlstate_span = parser.consume_keyword(Keyword::SQLSTATE)?;
320    let value_span = parser.skip_keyword(Keyword::VALUE);
321    let sql_state = parse_expression(parser, false)?;
322    let mut sets = Vec::new();
323    let set_span = parser.skip_keyword(Keyword::SET);
324    if set_span.is_some() {
325        loop {
326            let v = match &parser.token {
327                Token::Ident(_, Keyword::CLASS_ORIGIN) => {
328                    SignalConditionInformationName::ClassOrigin(parser.consume())
329                }
330                Token::Ident(_, Keyword::SUBCLASS_ORIGIN) => {
331                    SignalConditionInformationName::SubclassOrigin(parser.consume())
332                }
333                Token::Ident(_, Keyword::MESSAGE_TEXT) => {
334                    SignalConditionInformationName::MessageText(parser.consume())
335                }
336                Token::Ident(_, Keyword::MYSQL_ERRNO) => {
337                    SignalConditionInformationName::MysqlErrno(parser.consume())
338                }
339                Token::Ident(_, Keyword::CONSTRAINT_CATALOG) => {
340                    SignalConditionInformationName::ConstraintCatalog(parser.consume())
341                }
342                Token::Ident(_, Keyword::CONSTRAINT_SCHEMA) => {
343                    SignalConditionInformationName::ConstraintSchema(parser.consume())
344                }
345                Token::Ident(_, Keyword::CONSTRAINT_NAME) => {
346                    SignalConditionInformationName::ConstraintName(parser.consume())
347                }
348                Token::Ident(_, Keyword::CATALOG_NAME) => {
349                    SignalConditionInformationName::CatalogName(parser.consume())
350                }
351                Token::Ident(_, Keyword::SCHEMA_NAME) => {
352                    SignalConditionInformationName::SchemaName(parser.consume())
353                }
354                Token::Ident(_, Keyword::TABLE_NAME) => {
355                    SignalConditionInformationName::TableName(parser.consume())
356                }
357                Token::Ident(_, Keyword::COLUMN_NAME) => {
358                    SignalConditionInformationName::ColumnName(parser.consume())
359                }
360                Token::Ident(_, Keyword::CURSOR_NAME) => {
361                    SignalConditionInformationName::CursorName(parser.consume())
362                }
363                _ => parser.expected_failure("Condition information item name")?,
364            };
365            let eq_span = parser.consume_token(Token::Eq)?;
366            let value = parse_expression(parser, false)?;
367            sets.push((v, eq_span, value));
368            if parser.skip_token(Token::Comma).is_none() {
369                break;
370            }
371        }
372    }
373    Ok(Signal {
374        signal_span,
375        sqlstate_span,
376        value_span,
377        sql_state,
378        set_span,
379        sets,
380    })
381}
382
/// SQL statement
///
/// Variants without a doc comment wrap the AST node type of the same name.
#[derive(Clone, Debug)]
pub enum Statement<'a> {
    CreateIndex(CreateIndex<'a>),
    CreateTable(CreateTable<'a>),
    CreateView(CreateView<'a>),
    CreateTrigger(CreateTrigger<'a>),
    CreateFunction(CreateFunction<'a>),
    /// A single `SELECT` query
    Select(Select<'a>),
    /// Statement starting with `DELETE`
    Delete(Delete<'a>),
    /// Statement starting with `INSERT` or `REPLACE`
    InsertReplace(InsertReplace<'a>),
    /// Statement starting with `UPDATE`
    Update(Update<'a>),
    DropIndex(DropIndex<'a>),
    DropTable(DropTable<'a>),
    DropFunction(DropFunction<'a>),
    DropProcedure(DropProcedure<'a>),
    DropEvent(DropEvent<'a>),
    DropDatabase(DropDatabase<'a>),
    DropServer(DropServer<'a>),
    DropTrigger(DropTrigger<'a>),
    DropView(DropView<'a>),
    /// `SET name = expr, ...`
    Set(Set<'a>),
    /// `SIGNAL SQLSTATE ...`
    Signal(Signal<'a>),
    AlterTable(AlterTable<'a>),
    /// Statements of a `BEGIN ... END` block; produced when compound statements are
    /// permitted
    Block(Vec<Statement<'a>>), //TODO we should include begin and end
    /// Span of a bare `BEGIN`; produced when compound statements are not permitted
    Begin(Span),
    /// Span of a bare `END`; produced when compound statements are not permitted
    End(Span),
    /// Span of `COMMIT`
    Commit(Span),
    /// Span of `START TRANSACTION`
    StartTransaction(Span),
    /// `IF ... END IF`
    If(If<'a>),
    /// Invalid statement produced after recovering from parse error
    Invalid(Span),
    /// Queries combined with `UNION`
    Union(Union<'a>),
    /// Procedural `CASE` statement
    Case(CaseStatement<'a>),
    /// `COPY ... FROM STDIN`
    Copy(Copy<'a>),
    /// Raw data read from the input after a statement that reads from stdin (`COPY`),
    /// with its span
    Stdin(&'a str, Span),
    CreateTypeEnum(CreateTypeEnum<'a>),
    /// Statements of the block inside `DO $$ ... $$`
    Do(Vec<Statement<'a>>),
    /// Statement starting with `TRUNCATE`
    TruncateTable(TruncateTable<'a>),
    /// Statement starting with `RENAME`
    RenameTable(RenameTable<'a>),
    /// Statement starting with `WITH`
    WithQuery(WithQuery<'a>),
    /// `RETURN <expression>`
    Return(Return<'a>),
}
426
427impl<'a> Spanned for Statement<'a> {
428    fn span(&self) -> Span {
429        match &self {
430            Statement::CreateIndex(v) => v.span(),
431            Statement::CreateTable(v) => v.span(),
432            Statement::CreateView(v) => v.span(),
433            Statement::CreateTrigger(v) => v.span(),
434            Statement::CreateFunction(v) => v.span(),
435            Statement::Select(v) => v.span(),
436            Statement::Delete(v) => v.span(),
437            Statement::InsertReplace(v) => v.span(),
438            Statement::Update(v) => v.span(),
439            Statement::DropIndex(v) => v.span(),
440            Statement::DropTable(v) => v.span(),
441            Statement::DropFunction(v) => v.span(),
442            Statement::DropProcedure(v) => v.span(),
443            Statement::DropEvent(v) => v.span(),
444            Statement::DropDatabase(v) => v.span(),
445            Statement::DropServer(v) => v.span(),
446            Statement::DropTrigger(v) => v.span(),
447            Statement::DropView(v) => v.span(),
448            Statement::Set(v) => v.span(),
449            Statement::AlterTable(v) => v.span(),
450            Statement::Block(v) => v.opt_span().expect("Span of block"),
451            Statement::If(v) => v.span(),
452            Statement::Invalid(v) => v.span(),
453            Statement::Union(v) => v.span(),
454            Statement::Case(v) => v.span(),
455            Statement::Copy(v) => v.span(),
456            Statement::Stdin(_, s) => s.clone(),
457            Statement::Begin(s) => s.clone(),
458            Statement::End(s) => s.clone(),
459            Statement::Commit(s) => s.clone(),
460            Statement::StartTransaction(s) => s.clone(),
461            Statement::CreateTypeEnum(v) => v.span(),
462            Statement::Do(v) => v.opt_span().expect("Span of block"),
463            Statement::TruncateTable(v) => v.span(),
464            Statement::RenameTable(v) => v.span(),
465            Statement::WithQuery(v) => v.span(),
466            Statement::Return(v) => v.span(),
467            Statement::Signal(v) => v.span(),
468        }
469    }
470}
471
472impl Statement<'_> {
473    fn reads_from_stdin(&self) -> bool {
474        match self {
475            Statement::Copy(v) => v.reads_from_stdin(),
476            _ => false,
477        }
478    }
479}
480
481pub(crate) fn parse_statement<'a>(
482    parser: &mut Parser<'a, '_>,
483) -> Result<Option<Statement<'a>>, ParseError> {
484    Ok(match &parser.token {
485        Token::Ident(_, Keyword::CREATE) => Some(parse_create(parser)?),
486        Token::Ident(_, Keyword::DROP) => Some(parse_drop(parser)?),
487        Token::Ident(_, Keyword::SELECT) | Token::LParen => Some(parse_compound_query(parser)?),
488        Token::Ident(_, Keyword::DELETE) => Some(Statement::Delete(parse_delete(parser)?)),
489        Token::Ident(_, Keyword::INSERT | Keyword::REPLACE) => {
490            Some(Statement::InsertReplace(parse_insert_replace(parser)?))
491        }
492        Token::Ident(_, Keyword::UPDATE) => Some(Statement::Update(parse_update(parser)?)),
493        Token::Ident(_, Keyword::SET) => Some(Statement::Set(parse_set(parser)?)),
494        Token::Ident(_, Keyword::SIGNAL) => Some(Statement::Signal(parse_signal(parser)?)),
495        Token::Ident(_, Keyword::BEGIN) => Some(if parser.permit_compound_statements {
496            Statement::Block(parse_block(parser)?)
497        } else {
498            Statement::Begin(parse_begin(parser)?)
499        }),
500        Token::Ident(_, Keyword::END) if !parser.permit_compound_statements => {
501            Some(Statement::End(parse_end(parser)?))
502        }
503        Token::Ident(_, Keyword::START) => Some(parse_start(parser)?),
504        Token::Ident(_, Keyword::COMMIT) => Some(Statement::Commit(parse_commit(parser)?)),
505        Token::Ident(_, Keyword::IF) => Some(Statement::If(parse_if(parser)?)),
506        Token::Ident(_, Keyword::RETURN) => Some(Statement::Return(parse_return(parser)?)),
507        Token::Ident(_, Keyword::ALTER) => Some(parse_alter(parser)?),
508        Token::Ident(_, Keyword::CASE) => Some(Statement::Case(parse_case_statement(parser)?)),
509        Token::Ident(_, Keyword::COPY) => Some(Statement::Copy(parse_copy_statement(parser)?)),
510        Token::Ident(_, Keyword::DO) => Some(parse_do(parser)?),
511        Token::Ident(_, Keyword::TRUNCATE) => {
512            Some(Statement::TruncateTable(parse_truncate_table(parser)?))
513        }
514        Token::Ident(_, Keyword::RENAME) => {
515            Some(Statement::RenameTable(parse_rename_table(parser)?))
516        }
517        Token::Ident(_, Keyword::WITH) => Some(Statement::WithQuery(parse_with_query(parser)?)),
518        _ => None,
519    })
520}
521
522pub(crate) fn parse_do<'a>(parser: &mut Parser<'a, '_>) -> Result<Statement<'a>, ParseError> {
523    parser.consume_keyword(Keyword::DO)?;
524    parser.consume_token(Token::DoubleDollar)?;
525    let block = parse_block(parser)?;
526    parser.consume_token(Token::DoubleDollar)?;
527    Ok(Statement::Do(block))
528}
529
530/// When part of case statement
531#[derive(Clone, Debug)]
532pub struct WhenStatement<'a> {
533    /// Span of "WHEN"
534    pub when_span: Span,
535    /// Expression who's match yields execution `then`
536    pub when: Expression<'a>,
537    /// Span of "THEN"
538    pub then_span: Span,
539    /// Statements to execute if `when` matches
540    pub then: Vec<Statement<'a>>,
541}
542
543impl<'a> Spanned for WhenStatement<'a> {
544    fn span(&self) -> Span {
545        self.when_span
546            .join_span(&self.when)
547            .join_span(&self.then_span)
548            .join_span(&self.then)
549    }
550}
551
552/// Case statement
553#[derive(Clone, Debug)]
554pub struct CaseStatement<'a> {
555    /// Span of "CASE"
556    pub case_span: Span,
557    /// Value to match against
558    pub value: Option<Box<Expression<'a>>>,
559    /// List of whens
560    pub whens: Vec<WhenStatement<'a>>,
561    /// Span of "ELSE" and statement to execute if specified
562    pub else_: Option<(Span, Vec<Statement<'a>>)>,
563    /// Span of "END"
564    pub end_span: Span,
565}
566
567impl<'a> Spanned for CaseStatement<'a> {
568    fn span(&self) -> Span {
569        self.case_span
570            .join_span(&self.value)
571            .join_span(&self.whens)
572            .join_span(&self.else_)
573            .join_span(&self.end_span)
574    }
575}
576
577pub(crate) fn parse_case_statement<'a>(
578    parser: &mut Parser<'a, '_>,
579) -> Result<CaseStatement<'a>, ParseError> {
580    let case_span = parser.consume_keyword(Keyword::CASE)?;
581    let value = if !matches!(parser.token, Token::Ident(_, Keyword::WHEN)) {
582        Some(Box::new(parse_expression(parser, false)?))
583    } else {
584        None
585    };
586
587    let mut whens = Vec::new();
588    let mut else_ = None;
589    parser.recovered(
590        "'END'",
591        &|t| matches!(t, Token::Ident(_, Keyword::END)),
592        |parser| {
593            loop {
594                let when_span = parser.consume_keyword(Keyword::WHEN)?;
595                let when = parse_expression(parser, false)?;
596                let then_span = parser.consume_keyword(Keyword::THEN)?;
597                let mut then = Vec::new();
598                parse_statement_list(parser, &mut then)?;
599                whens.push(WhenStatement {
600                    when_span,
601                    when,
602                    then_span,
603                    then,
604                });
605                if !matches!(parser.token, Token::Ident(_, Keyword::WHEN)) {
606                    break;
607                }
608            }
609            if let Some(span) = parser.skip_keyword(Keyword::ELSE) {
610                let mut e = Vec::new();
611                parse_statement_list(parser, &mut e)?;
612                else_ = Some((span, e))
613            };
614            Ok(())
615        },
616    )?;
617    let end_span = parser.consume_keyword(Keyword::END)?;
618    Ok(CaseStatement {
619        case_span,
620        value,
621        whens,
622        else_,
623        end_span,
624    })
625}
626
627pub(crate) fn parse_copy_statement<'a>(
628    parser: &mut Parser<'a, '_>,
629) -> Result<Copy<'a>, ParseError> {
630    let copy_span = parser.consume_keyword(Keyword::COPY)?;
631    let table = parser.consume_plain_identifier()?;
632    parser.consume_token(Token::LParen)?;
633    let mut columns = Vec::new();
634    if !matches!(parser.token, Token::RParen) {
635        loop {
636            parser.recovered(
637                "')' or ','",
638                &|t| matches!(t, Token::RParen | Token::Comma),
639                |parser| {
640                    columns.push(parser.consume_plain_identifier()?);
641                    Ok(())
642                },
643            )?;
644            if matches!(parser.token, Token::RParen) {
645                break;
646            }
647            parser.consume_token(Token::Comma)?;
648        }
649    }
650    parser.consume_token(Token::RParen)?;
651    let from_span = parser.consume_keyword(Keyword::FROM)?;
652    let stdin_span = parser.consume_keyword(Keyword::STDIN)?;
653
654    Ok(Copy {
655        copy_span,
656        table,
657        columns,
658        from_span,
659        stdin_span,
660    })
661}
662
663pub(crate) fn parse_compound_query_bottom<'a>(
664    parser: &mut Parser<'a, '_>,
665) -> Result<Statement<'a>, ParseError> {
666    match &parser.token {
667        Token::LParen => {
668            let lp = parser.consume_token(Token::LParen)?;
669            let s = parser.recovered("')'", &|t| t == &Token::RParen, |parser| {
670                Ok(Some(parse_compound_query(parser)?))
671            })?;
672            parser.consume_token(Token::RParen)?;
673            Ok(s.unwrap_or(Statement::Invalid(lp)))
674        }
675        Token::Ident(_, Keyword::SELECT) => Ok(Statement::Select(parse_select(parser)?)),
676        _ => parser.expected_failure("'SELECET' or '('")?,
677    }
678}
679
680/// Type of union to perform
681#[derive(Clone, Debug)]
682pub enum UnionType {
683    All(Span),
684    Distinct(Span),
685    Default,
686}
687
688impl OptSpanned for UnionType {
689    fn opt_span(&self) -> Option<Span> {
690        match &self {
691            UnionType::All(v) => v.opt_span(),
692            UnionType::Distinct(v) => v.opt_span(),
693            UnionType::Default => None,
694        }
695    }
696}
697
698/// Right hand side of a union expression
699#[derive(Clone, Debug)]
700pub struct UnionWith<'a> {
701    /// Span of "UNION"
702    pub union_span: Span,
703    /// Type of union to perform
704    pub union_type: UnionType,
705    /// Statement to union
706    pub union_statement: Box<Statement<'a>>,
707}
708
709impl<'a> Spanned for UnionWith<'a> {
710    fn span(&self) -> Span {
711        self.union_span
712            .join_span(&self.union_type)
713            .join_span(&self.union_statement)
714    }
715}
716
717/// Union statement
718#[derive(Clone, Debug)]
719pub struct Union<'a> {
720    /// Left side of union
721    pub left: Box<Statement<'a>>,
722    /// List of things to union
723    pub with: Vec<UnionWith<'a>>,
724    /// Span of "ORDER BY", and list of ordering expressions and directions if specified
725    pub order_by: Option<(Span, Vec<(Expression<'a>, OrderFlag)>)>,
726    /// Span of "LIMIT", offset and count expressions if specified
727    pub limit: Option<(Span, Option<Expression<'a>>, Expression<'a>)>,
728}
729
730impl<'a> Spanned for Union<'a> {
731    fn span(&self) -> Span {
732        self.left
733            .join_span(&self.with)
734            .join_span(&self.order_by)
735            .join_span(&self.limit)
736    }
737}
738
739#[derive(Clone, Debug)]
740pub struct Copy<'a> {
741    pub copy_span: Span,
742    pub table: Identifier<'a>,
743    pub columns: Vec<Identifier<'a>>,
744    pub from_span: Span,
745    pub stdin_span: Span,
746}
747
748impl<'a> Spanned for Copy<'a> {
749    fn span(&self) -> Span {
750        self.copy_span
751            .join_span(&self.table)
752            .join_span(&self.columns)
753            .join_span(&self.from_span)
754            .join_span(&self.stdin_span)
755    }
756}
757
758impl<'a> Copy<'a> {
759    fn reads_from_stdin(&self) -> bool {
760        // There are COPY statements that don't read from STDIN,
761        // but we don't support them in this parser - we only support FROM STDIN.
762        true
763    }
764}
765
/// Parses a query that may combine `SELECT`s (or parenthesised queries) with
/// `UNION [ALL | DISTINCT]`, optionally followed by `ORDER BY` and `LIMIT`.
///
/// If no `UNION` follows the first operand, that operand is returned as is, and no
/// `ORDER BY`/`LIMIT` is parsed here. Otherwise a `Statement::Union` is returned.
pub(crate) fn parse_compound_query<'a>(
    parser: &mut Parser<'a, '_>,
) -> Result<Statement<'a>, ParseError> {
    let q = parse_compound_query_bottom(parser)?;
    if !matches!(parser.token, Token::Ident(_, Keyword::UNION)) {
        return Ok(q);
    };
    let mut with = Vec::new();
    // One iteration per `UNION [ALL | DISTINCT] <operand>`.
    loop {
        let union_span = parser.consume_keyword(Keyword::UNION)?;
        let union_type = match &parser.token {
            Token::Ident(_, Keyword::ALL) => UnionType::All(parser.consume_keyword(Keyword::ALL)?),
            Token::Ident(_, Keyword::DISTINCT) => {
                UnionType::Distinct(parser.consume_keyword(Keyword::DISTINCT)?)
            }
            _ => UnionType::Default,
        };
        let union_statement = Box::new(parse_compound_query_bottom(parser)?);
        with.push(UnionWith {
            union_span,
            union_type,
            union_statement,
        });
        if !matches!(parser.token, Token::Ident(_, Keyword::UNION)) {
            break;
        }
    }

    // `ORDER BY <expr> [ASC | DESC] [, ...]` applying to the whole union; the stored
    // span covers both keywords.
    let order_by = if let Some(span) = parser.skip_keyword(Keyword::ORDER) {
        let span = parser.consume_keyword(Keyword::BY)?.join_span(&span);
        let mut order = Vec::new();
        loop {
            let e = parse_expression(parser, false)?;
            let f = match &parser.token {
                Token::Ident(_, Keyword::ASC) => OrderFlag::Asc(parser.consume()),
                Token::Ident(_, Keyword::DESC) => OrderFlag::Desc(parser.consume()),
                _ => OrderFlag::None,
            };
            order.push((e, f));
            if parser.skip_token(Token::Comma).is_none() {
                break;
            }
        }
        Some((span, order))
    } else {
        None
    };

    // Both `LIMIT <offset>, <count>` and `LIMIT <count> OFFSET <offset>` are stored
    // as (span, Some(offset), count); a bare `LIMIT <count>` has no offset.
    let limit = if let Some(span) = parser.skip_keyword(Keyword::LIMIT) {
        let n = parse_expression(parser, true)?;
        match parser.token {
            Token::Comma => {
                parser.consume();
                Some((span, Some(n), parse_expression(parser, true)?))
            }
            Token::Ident(_, Keyword::OFFSET) => {
                parser.consume();
                Some((span, Some(parse_expression(parser, true)?), n))
            }
            _ => Some((span, None, n)),
        }
    } else {
        None
    };

    Ok(Statement::Union(Union {
        left: Box::new(q),
        with,
        order_by,
        limit,
    }))
}
838
/// Parses every statement up to end of input and returns the ones that parsed
/// successfully.
///
/// Statements are terminated by `parser.delimiter`. A `DELIMITER <token>` directive
/// replaces the delimiter for the rest of the input. Tokens other than `$$` and `;`
/// are still accepted but produce an "Unknown delimiter" warning.
///
/// Errors are not returned from here. After a failed statement, the parser skips
/// ahead to the next delimiter and continues. Diagnostics are presumably collected on
/// the parser itself; confirm against `Parser`.
pub(crate) fn parse_statements<'a>(parser: &mut Parser<'a, '_>) -> Vec<Statement<'a>> {
    let mut ans = Vec::new();
    loop {
        // Skip empty statements (repeated delimiters); stop at end of input.
        loop {
            match &parser.token {
                t if t == &parser.delimiter => {
                    parser.consume();
                }
                Token::Eof => return ans,
                _ => break,
            }
        }

        // `DELIMITER <token>` switches the statement terminator for what follows.
        if parser.skip_keyword(Keyword::DELIMITER).is_some() {
            let t = parser.token.clone();

            if !matches!(t, Token::DoubleDollar | Token::SemiColon) {
                parser.warn("Unknown delimiter", &parser.span.span());
            }
            parser.delimiter = t;
            parser.next();
            continue;
        }

        // A token that starts no known statement is turned into an error as well.
        let stmt = match parse_statement(parser) {
            Ok(Some(v)) => Ok(v),
            Ok(None) => parser.expected_failure("Statement"),
            Err(e) => Err(e),
        };
        let err = stmt.is_err();
        let mut from_stdin = false;
        if let Ok(stmt) = stmt {
            from_stdin = stmt.reads_from_stdin();
            ans.push(stmt);
        }

        if parser.token != parser.delimiter {
            // Only report the missing delimiter for a statement that parsed; a failed
            // statement has presumably been reported already.
            if !err {
                parser.expected_error(parser.delimiter.name());
            }
            // We use a custom recovery here as ; is not allowed in sub expressions, it always terminates outer most statements
            loop {
                parser.next();
                match &parser.token {
                    t if t == &parser.delimiter => break,
                    Token::Eof => return ans,
                    _ => (),
                }
            }
        }
        // The current token is now the delimiter.
        if from_stdin {
            // Raw `COPY ... FROM STDIN` data follows the delimiter; it is read
            // verbatim into a separate `Stdin` statement.
            let (s, span) = parser.read_from_stdin_and_next();
            ans.push(Statement::Stdin(s, span));
        } else {
            parser
                .consume_token(parser.delimiter.clone())
                .expect("Delimiter");
        }
    }
}