sql_parse/
update.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::vec;
14use alloc::vec::Vec;
15
16use crate::{
17    expression::{parse_expression, Expression},
18    keywords::Keyword,
19    lexer::Token,
20    parser::{ParseError, Parser},
21    select::{parse_select_expr, parse_table_reference, TableReference},
22    span::OptSpanned,
23    Identifier, SelectExpr, Span, Spanned,
24};
25
26/// Flags specified after "UPDATE"
27#[derive(Clone, Debug)]
28pub enum UpdateFlag {
29    LowPriority(Span),
30    Ignore(Span),
31}
32
33impl Spanned for UpdateFlag {
34    fn span(&self) -> Span {
35        match &self {
36            UpdateFlag::LowPriority(v) => v.span(),
37            UpdateFlag::Ignore(v) => v.span(),
38        }
39    }
40}
41
42/// Representation of replace Statement
43///
44/// ```
45/// # use sql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, Update, Statement, Issues};
46/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
47/// #
48/// let sql = "UPDATE tab1, tab2 SET tab1.column1 = value1, tab1.column2 = value2 WHERE tab1.id = tab2.id";
49/// let mut issues = Issues::new(sql);
50/// let stmt = parse_statement(sql, &mut issues, &options);
51///
52/// # assert!(issues.is_ok());
53/// let u: Update = match stmt {
54///     Some(Statement::Update(u)) => u,
55///     _ => panic!("We should get an update statement")
56/// };
57///
58/// println!("{:#?}", u.where_.unwrap())
59/// ```
60#[derive(Clone, Debug)]
61pub struct Update<'a> {
62    /// Span of "UPDATE"
63    pub update_span: Span,
64    /// Flags specified after "UPDATE"
65    pub flags: Vec<UpdateFlag>,
66    /// List of tables to update
67    pub tables: Vec<TableReference<'a>>,
68    /// Span of "SET"
69    pub set_span: Span,
70    /// List of key,value pairs to set
71    pub set: Vec<(Vec<Identifier<'a>>, Expression<'a>)>,
72    /// Where expression and span of "WHERE" if specified
73    pub where_: Option<(Expression<'a>, Span)>,
74    /// Span of "RETURNING" and select expressions after "RETURNING", if "RETURNING" is present
75    pub returning: Option<(Span, Vec<SelectExpr<'a>>)>,
76}
77
78impl<'a> Spanned for Update<'a> {
79    fn span(&self) -> Span {
80        let mut set_span = None;
81        for (a, b) in &self.set {
82            set_span = set_span.opt_join_span(a).opt_join_span(b)
83        }
84
85        self.update_span
86            .join_span(&self.flags)
87            .join_span(&self.tables)
88            .join_span(&self.set_span)
89            .join_span(&set_span)
90            .join_span(&self.where_)
91            .join_span(&self.returning)
92    }
93}
94
95pub(crate) fn parse_update<'a>(parser: &mut Parser<'a, '_>) -> Result<Update<'a>, ParseError> {
96    let update_span = parser.consume_keyword(Keyword::UPDATE)?;
97    let mut flags = Vec::new();
98
99    loop {
100        match &parser.token {
101            Token::Ident(_, Keyword::LOW_PRIORITY) => flags.push(UpdateFlag::LowPriority(
102                parser.consume_keyword(Keyword::LOW_PRIORITY)?,
103            )),
104            Token::Ident(_, Keyword::IGNORE) => {
105                flags.push(UpdateFlag::Ignore(parser.consume_keyword(Keyword::IGNORE)?))
106            }
107            _ => break,
108        }
109    }
110
111    let mut tables = Vec::new();
112    loop {
113        tables.push(parse_table_reference(parser)?);
114        if parser.skip_token(Token::Comma).is_none() {
115            break;
116        }
117    }
118
119    let set_span = parser.consume_keyword(Keyword::SET)?;
120    let mut set = Vec::new();
121    loop {
122        let mut col = vec![parser.consume_plain_identifier()?];
123        while parser.skip_token(Token::Period).is_some() {
124            col.push(parser.consume_plain_identifier()?);
125        }
126        parser.consume_token(Token::Eq)?;
127        let val = parse_expression(parser, false)?;
128        set.push((col, val));
129        if parser.skip_token(Token::Comma).is_none() {
130            break;
131        }
132    }
133
134    let where_ = if let Some(span) = parser.skip_keyword(Keyword::WHERE) {
135        Some((parse_expression(parser, false)?, span))
136    } else {
137        None
138    };
139
140    let returning = if let Some(returning_span) = parser.skip_keyword(Keyword::RETURNING) {
141        let mut returning_exprs = Vec::new();
142        loop {
143            returning_exprs.push(parse_select_expr(parser)?);
144            if parser.skip_token(Token::Comma).is_none() {
145                break;
146            }
147        }
148        if !parser.options.dialect.is_postgresql() {
149            parser.err("Only support by postgesql", &returning_span);
150        }
151        Some((returning_span, returning_exprs))
152    } else {
153        None
154    };
155
156    Ok(Update {
157        flags,
158        update_span,
159        tables,
160        set_span,
161        set,
162        where_,
163        returning,
164    })
165}
166
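// Syntax handled by `parse_update`:
// UPDATE [LOW_PRIORITY] [IGNORE] table_references
// SET col1={expr1|DEFAULT} [, col2={expr2|DEFAULT}] ...
// [WHERE where_condition]
// [RETURNING select_expr [, select_expr] ...]   -- reported as an error unless the dialect is PostgreSQL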