sql_parse/
update.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::vec;
14use alloc::vec::Vec;
15
16use crate::{
17    expression::{parse_expression, Expression},
18    keywords::Keyword,
19    lexer::Token,
20    parser::{ParseError, Parser},
21    select::{parse_table_reference, TableReference},
22    span::OptSpanned,
23    Identifier, Span, Spanned,
24};
25
26/// Flags specified after "UPDATE"
27#[derive(Clone, Debug)]
28pub enum UpdateFlag {
29    LowPriority(Span),
30    Ignore(Span),
31}
32
33impl Spanned for UpdateFlag {
34    fn span(&self) -> Span {
35        match &self {
36            UpdateFlag::LowPriority(v) => v.span(),
37            UpdateFlag::Ignore(v) => v.span(),
38        }
39    }
40}
41
42/// Representation of replace Statement
43///
44/// ```
45/// # use sql_parse::{SQLDialect, SQLArguments, ParseOptions, parse_statement, Update, Statement, Issues};
46/// # let options = ParseOptions::new().dialect(SQLDialect::MariaDB);
47/// #
48/// let sql = "UPDATE tab1, tab2 SET tab1.column1 = value1, tab1.column2 = value2 WHERE tab1.id = tab2.id";
49/// let mut issues = Issues::new(sql);
50/// let stmt = parse_statement(sql, &mut issues, &options);
51///
52/// # assert!(issues.is_ok());
53/// let u: Update = match stmt {
54///     Some(Statement::Update(u)) => u,
55///     _ => panic!("We should get an update statement")
56/// };
57///
58/// println!("{:#?}", u.where_.unwrap())
59/// ```
60#[derive(Clone, Debug)]
61pub struct Update<'a> {
62    /// Span of "UPDATE"
63    pub update_span: Span,
64    /// Flags specified after "UPDATE"
65    pub flags: Vec<UpdateFlag>,
66    /// List of tables to update
67    pub tables: Vec<TableReference<'a>>,
68    /// Span of "SET"
69    pub set_span: Span,
70    /// List of key,value pairs to set
71    pub set: Vec<(Vec<Identifier<'a>>, Expression<'a>)>,
72    /// Where expression and span of "WHERE" if specified
73    pub where_: Option<(Expression<'a>, Span)>,
74}
75
76impl<'a> Spanned for Update<'a> {
77    fn span(&self) -> Span {
78        let mut set_span = None;
79        for (a, b) in &self.set {
80            set_span = set_span.opt_join_span(a).opt_join_span(b)
81        }
82
83        self.update_span
84            .join_span(&self.flags)
85            .join_span(&self.tables)
86            .join_span(&self.set_span)
87            .join_span(&set_span)
88            .join_span(&self.where_)
89    }
90}
91
92pub(crate) fn parse_update<'a>(parser: &mut Parser<'a, '_>) -> Result<Update<'a>, ParseError> {
93    let update_span = parser.consume_keyword(Keyword::UPDATE)?;
94    let mut flags = Vec::new();
95
96    loop {
97        match &parser.token {
98            Token::Ident(_, Keyword::LOW_PRIORITY) => flags.push(UpdateFlag::LowPriority(
99                parser.consume_keyword(Keyword::LOW_PRIORITY)?,
100            )),
101            Token::Ident(_, Keyword::IGNORE) => {
102                flags.push(UpdateFlag::Ignore(parser.consume_keyword(Keyword::IGNORE)?))
103            }
104            _ => break,
105        }
106    }
107
108    let mut tables = Vec::new();
109    loop {
110        tables.push(parse_table_reference(parser)?);
111        if parser.skip_token(Token::Comma).is_none() {
112            break;
113        }
114    }
115
116    let set_span = parser.consume_keyword(Keyword::SET)?;
117    let mut set = Vec::new();
118    loop {
119        let mut col = vec![parser.consume_plain_identifier()?];
120        while parser.skip_token(Token::Period).is_some() {
121            col.push(parser.consume_plain_identifier()?);
122        }
123        parser.consume_token(Token::Eq)?;
124        let val = parse_expression(parser, false)?;
125        set.push((col, val));
126        if parser.skip_token(Token::Comma).is_none() {
127            break;
128        }
129    }
130
131    let where_ = if let Some(span) = parser.skip_keyword(Keyword::WHERE) {
132        Some((parse_expression(parser, false)?, span))
133    } else {
134        None
135    };
136
137    Ok(Update {
138        flags,
139        update_span,
140        tables,
141        set_span,
142        set,
143        where_,
144    })
145}
146
147// UPDATE [LOW_PRIORITY] [IGNORE] table_references
148// SET col1={expr1|DEFAULT} [, col2={expr2|DEFAULT}] ...
149// [WHERE where_condition]