// sqlparser_mysql/dms/update.rs

1use std::{fmt, str};
2
3use nom::bytes::complete::tag_no_case;
4use nom::character::complete::{multispace0, multispace1};
5use nom::combinator::opt;
6use nom::sequence::tuple;
7use nom::IResult;
8
9use base::column::Column;
10use base::condition::ConditionExpression;
11use base::error::ParseSQLError;
12use base::table::Table;
13use base::{CommonParser, DisplayUtil, FieldValueExpression};
14
15#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
16pub struct UpdateStatement {
17    pub table: Table,
18    pub fields: Vec<(Column, FieldValueExpression)>,
19    pub where_clause: Option<ConditionExpression>,
20}
21
22impl UpdateStatement {
23    pub fn parse(i: &str) -> IResult<&str, UpdateStatement, ParseSQLError<&str>> {
24        let (remaining_input, (_, _, table, _, _, _, fields, _, where_clause, _)) = tuple((
25            tag_no_case("UPDATE"),
26            multispace1,
27            Table::table_reference,
28            multispace1,
29            tag_no_case("SET"),
30            multispace1,
31            FieldValueExpression::assignment_expr_list,
32            multispace0,
33            opt(ConditionExpression::parse),
34            CommonParser::statement_terminator,
35        ))(i)?;
36        Ok((
37            remaining_input,
38            UpdateStatement {
39                table,
40                fields,
41                where_clause,
42            },
43        ))
44    }
45}
46
47impl fmt::Display for UpdateStatement {
48    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
49        write!(
50            f,
51            "UPDATE {} ",
52            DisplayUtil::escape_if_keyword(&self.table.name)
53        )?;
54        assert!(!self.fields.is_empty());
55        write!(
56            f,
57            "SET {}",
58            self.fields
59                .iter()
60                .map(|(col, literal)| format!("{} = {}", col, literal))
61                .collect::<Vec<_>>()
62                .join(", ")
63        )?;
64        if let Some(ref where_clause) = self.where_clause {
65            write!(f, " WHERE ")?;
66            write!(f, "{}", where_clause)?;
67        }
68        Ok(())
69    }
70}