sqlparser_mysql/dms/
insert.rs

1use std::fmt;
2use std::str;
3
4use nom::bytes::complete::{tag, tag_no_case};
5use nom::character::complete::{multispace0, multispace1};
6use nom::combinator::opt;
7use nom::multi::many1;
8use nom::sequence::{delimited, preceded, tuple};
9use nom::IResult;
10
11use base::column::Column;
12use base::error::ParseSQLError;
13use base::table::Table;
14use base::{CommonParser, DisplayUtil, FieldValueExpression, Literal};
15
16#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
17pub struct InsertStatement {
18    pub table: Table,
19    pub fields: Option<Vec<Column>>,
20    pub data: Vec<Vec<Literal>>,
21    pub ignore: bool,
22    pub on_duplicate: Option<Vec<(Column, FieldValueExpression)>>,
23}
24
25impl InsertStatement {
26    // Parse rule for a SQL insert query.
27    // TODO(malte): support REPLACE, nested selection, DEFAULT VALUES
28    pub fn parse(i: &str) -> IResult<&str, InsertStatement, ParseSQLError<&str>> {
29        let (
30            remaining_input,
31            (_, ignore_res, _, _, _, table, _, fields, _, _, data, on_duplicate, _, _),
32        ) = tuple((
33            tag_no_case("INSERT"),
34            opt(preceded(multispace1, tag_no_case("IGNORE"))),
35            multispace1,
36            tag_no_case("INTO"),
37            multispace1,
38            Table::schema_table_reference,
39            multispace0,
40            opt(Self::fields),
41            tag_no_case("VALUES"),
42            multispace0,
43            many1(Self::data),
44            opt(Self::on_duplicate),
45            multispace0,
46            CommonParser::statement_terminator,
47        ))(i)?;
48        assert!(table.alias.is_none());
49        let ignore = ignore_res.is_some();
50
51        Ok((
52            remaining_input,
53            InsertStatement {
54                table,
55                fields,
56                data,
57                ignore,
58                on_duplicate,
59            },
60        ))
61    }
62
63    fn fields(i: &str) -> IResult<&str, Vec<Column>, ParseSQLError<&str>> {
64        delimited(
65            preceded(tag("("), multispace0),
66            Column::field_list,
67            delimited(multispace0, tag(")"), multispace1),
68        )(i)
69    }
70
71    fn data(i: &str) -> IResult<&str, Vec<Literal>, ParseSQLError<&str>> {
72        delimited(
73            tag("("),
74            Literal::value_list,
75            preceded(tag(")"), opt(CommonParser::ws_sep_comma)),
76        )(i)
77    }
78
79    pub fn on_duplicate(
80        i: &str,
81    ) -> IResult<&str, Vec<(Column, FieldValueExpression)>, ParseSQLError<&str>> {
82        preceded(
83            multispace0,
84            preceded(
85                tuple((
86                    tag_no_case("ON"),
87                    multispace1,
88                    tag_no_case("DUPLICATE"),
89                    multispace1,
90                    tag_no_case("KEY"),
91                    multispace1,
92                    tag_no_case("UPDATE"),
93                )),
94                preceded(multispace1, FieldValueExpression::assignment_expr_list),
95            ),
96        )(i)
97    }
98}
99
100impl fmt::Display for InsertStatement {
101    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
102        write!(
103            f,
104            "INSERT INTO {}",
105            DisplayUtil::escape_if_keyword(&self.table.name)
106        )?;
107        if let Some(ref fields) = self.fields {
108            write!(
109                f,
110                " ({})",
111                fields
112                    .iter()
113                    .map(|col| col.name.to_owned())
114                    .collect::<Vec<_>>()
115                    .join(", ")
116            )?;
117        }
118        write!(
119            f,
120            " VALUES {}",
121            self.data
122                .iter()
123                .map(|data| format!(
124                    "({})",
125                    data.iter()
126                        .map(|l| l.to_string())
127                        .collect::<Vec<_>>()
128                        .join(", ")
129                ))
130                .collect::<Vec<_>>()
131                .join(", ")
132        )
133    }
134}