sql_type/
type_insert_replace.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use sql_parse::{
15    issue_todo, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16    OptSpanned, Spanned,
17};
18
19use crate::{
20    type_expression::{type_expression, ExpressionFlags},
21    type_select::{type_select, type_select_exprs, SelectType},
22    typer::{typer_stack, unqualified_name, ReferenceType, Typer},
23    BaseType, SelectTypeColumn, Type,
24};
25
/// Does the insert yield an auto increment id
///
/// The value is computed by `type_insert_replace` from the target table's
/// schema and the shape of the statement.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// The table has an auto increment column and the statement is a plain
    /// `INSERT` (no `IGNORE` flag, no `ON DUPLICATE KEY UPDATE`).
    Yes,
    /// The table is unknown or has no auto increment column, or the statement
    /// is not an `INSERT`.
    No,
    /// The table has an auto increment column, but the `INSERT` carries the
    /// `IGNORE` flag or an `ON DUPLICATE KEY UPDATE` clause.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35    typer: &mut Typer<'a, '_>,
36    ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38    let table = unqualified_name(typer.issues, &ior.table);
39    let columns = &ior.columns;
40
41    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42        if schema.view {
43            typer.err("Inserts into views not yet implemented", table);
44        }
45        let mut col_types = Vec::new();
46
47        for col in columns {
48            if let Some(schema_col) = schema.get_column(col.value) {
49                col_types.push((schema_col.type_.clone(), col.span()));
50            } else {
51                typer.err("No such column in schema", col);
52            }
53        }
54
55        if let Some(set) = &ior.set {
56            for col in &schema.columns {
57                if col.auto_increment
58                    || col.default
59                    || !col.type_.not_null
60                    || col.as_.is_some()
61                    || col.generated
62                    || set.pairs.iter().any(|v| v.column == col.identifier)
63                {
64                    continue;
65                }
66                typer.err(
67                    format!(
68                        "No value for column {} provided, but it has no default value",
69                        &col.identifier
70                    ),
71                    set,
72                );
73            }
74        } else {
75            for col in &schema.columns {
76                if col.auto_increment
77                    || col.default
78                    || !col.type_.not_null
79                    || col.as_.is_some()
80                    || col.generated
81                    || columns.contains(&col.identifier)
82                {
83                    continue;
84                }
85                typer.err(
86                    format!(
87                        "No value for column {} provided, but it has no default value",
88                        &col.identifier
89                    ),
90                    &columns.opt_span().unwrap(),
91                );
92            }
93        }
94
95        (
96            Some(col_types),
97            schema.columns.iter().any(|c| c.auto_increment),
98        )
99    } else {
100        typer.err("Unknown table", table);
101        (None, false)
102    };
103
104    if let Some(values) = &ior.values {
105        for row in &values.1 {
106            for (j, e) in row.iter().enumerate() {
107                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109                    if typer.matched_type(&t, et).is_none() {
110                        typer
111                            .err(format!("Got type {}", t.t), e)
112                            .frag(format!("Expected {}", et.t), ets);
113                    } else if let Type::Args(_, args) = &t.t {
114                        for (idx, arg_type, _) in args.iter() {
115                            typer.constrain_arg(*idx, arg_type, et);
116                        }
117                    }
118                } else {
119                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120                }
121            }
122            if let Some(s) = &s {
123                if s.len() != row.len() {
124                    typer
125                        .err(
126                            format!("Got {} columns", row.len()),
127                            &row.opt_span().unwrap(),
128                        )
129                        .frag(
130                            format!("Expected {}", columns.len()),
131                            &columns.opt_span().unwrap(),
132                        );
133                }
134            }
135        }
136    }
137
138    if let Some(select) = &ior.select {
139        let select = type_select(typer, select, true);
140        if let Some(s) = s {
141            for i in 0..usize::max(s.len(), select.columns.len()) {
142                match (s.get(i), select.columns.get(i)) {
143                    (Some((et, ets)), Some(t)) => {
144                        if typer.matched_type(&t.type_, et).is_none() {
145                            typer
146                                .err(format!("Got type {}", t.type_.t), &t.span)
147                                .frag(format!("Expected {}", et.t), ets);
148                        }
149                    }
150                    (None, Some(t)) => {
151                        typer.err("Column in select not in insert", &t.span);
152                    }
153                    (Some((_, ets)), None) => {
154                        typer.err("Missing column in select", ets);
155                    }
156                    (None, None) => {
157                        panic!("ICE")
158                    }
159                }
160            }
161        }
162    }
163
164    let mut guard = typer_stack(
165        typer,
166        |t| core::mem::take(&mut t.reference_types),
167        |t, v| t.reference_types = v,
168    );
169    let typer = &mut guard.typer;
170
171    if let Some(s) = typer.schemas.schemas.get(table.value) {
172        let mut columns = Vec::new();
173        for c in &s.columns {
174            columns.push((c.identifier.clone(), c.type_.clone()));
175        }
176        for v in &typer.reference_types {
177            if v.name == Some(table.clone()) {
178                typer
179                    .issues
180                    .err("Duplicate definitions", table)
181                    .frag("Already defined here", &v.span);
182            }
183        }
184        typer.reference_types.push(ReferenceType {
185            name: Some(table.clone()),
186            span: table.span(),
187            columns,
188        });
189    }
190
191    if let Some(set) = &ior.set {
192        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
193            let mut cnt = 0;
194            let mut t = None;
195            for r in &typer.reference_types {
196                for c in &r.columns {
197                    if c.0 == *column {
198                        cnt += 1;
199                        t = Some(c.clone());
200                    }
201                }
202            }
203            if cnt > 1 {
204                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
205                let mut issue = typer.issues.err("Ambiguous reference", column);
206                for r in &typer.reference_types {
207                    for c in &r.columns {
208                        if c.0 == *column {
209                            issue.frag("Defined here", &r.span);
210                        }
211                    }
212                }
213            } else if let Some(t) = t {
214                let value_type =
215                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
216                if typer.matched_type(&value_type, &t.1).is_none() {
217                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
218                } else if let Type::Args(_, args) = &value_type.t {
219                    for (idx, arg_type, _) in args.iter() {
220                        typer.constrain_arg(*idx, arg_type, &t.1);
221                    }
222                }
223            } else {
224                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
225                typer.err("Unknown identifier", column);
226            }
227        }
228    }
229
230    if let Some(up) = &ior.on_duplicate_key_update {
231        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
232            let mut cnt = 0;
233            let mut t = None;
234            for r in &typer.reference_types {
235                for c in &r.columns {
236                    if c.0 == *column {
237                        cnt += 1;
238                        t = Some(c.clone());
239                    }
240                }
241            }
242            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
243            if cnt > 1 {
244                type_expression(typer, value, flags, BaseType::Any);
245                let mut issue = typer.issues.err("Ambiguous reference", column);
246                for r in &typer.reference_types {
247                    for c in &r.columns {
248                        if c.0 == *column {
249                            issue.frag("Defined here", &r.span);
250                        }
251                    }
252                }
253            } else if let Some(t) = t {
254                let value_type = type_expression(typer, value, flags, t.1.base());
255                if typer.matched_type(&value_type, &t.1).is_none() {
256                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
257                } else if let Type::Args(_, args) = &value_type.t {
258                    for (idx, arg_type, _) in args.iter() {
259                        typer.constrain_arg(*idx, arg_type, &t.1);
260                    }
261                }
262            } else {
263                type_expression(typer, value, flags, BaseType::Any);
264                typer.err("Unknown identifier", column);
265            }
266        }
267    }
268
269    if let Some(on_conflict) = &ior.on_conflict {
270        match &on_conflict.target {
271            sql_parse::OnConflictTarget::Column { name } => {
272                let mut t = None;
273                for r in &typer.reference_types {
274                    for c in &r.columns {
275                        if c.0 == *name {
276                            t = Some(c.clone());
277                        }
278                    }
279                }
280                if t.is_none() {
281                    typer.err("Unknown identifier", name);
282                }
283                //TODO check if there is a unique constraint on column
284            }
285            sql_parse::OnConflictTarget::OnConstraint {
286                on_constraint_span, ..
287            } => {
288                issue_todo!(typer.issues, on_constraint_span);
289            }
290            sql_parse::OnConflictTarget::None => (),
291        }
292
293        match &on_conflict.action {
294            sql_parse::OnConflictAction::DoNothing(_) => (),
295            sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
296                for (key, value) in sets {
297                    let mut cnt = 0;
298                    let mut t = None;
299                    for r in &typer.reference_types {
300                        for c in &r.columns {
301                            if c.0 == *key {
302                                cnt += 1;
303                                t = Some(c.clone());
304                            }
305                        }
306                    }
307                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
308                    if cnt > 1 {
309                        type_expression(typer, value, flags, BaseType::Any);
310                        let mut issue = typer.issues.err("Ambiguous reference", key);
311                        for r in &typer.reference_types {
312                            for c in &r.columns {
313                                if c.0 == *key {
314                                    issue.frag("Defined here", &r.span);
315                                }
316                            }
317                        }
318                    } else if let Some(t) = t {
319                        let value_type = type_expression(typer, value, flags, t.1.base());
320                        if typer.matched_type(&value_type, &t.1).is_none() {
321                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
322                        } else if let Type::Args(_, args) = &value_type.t {
323                            for (idx, arg_type, _) in args.iter() {
324                                typer.constrain_arg(*idx, arg_type, &t.1);
325                            }
326                        }
327                    } else {
328                        type_expression(typer, value, flags, BaseType::Any);
329                        typer.err("Unknown identifier", key);
330                    }
331                }
332                if let Some((_, where_)) = where_ {
333                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
334                }
335            }
336        }
337    }
338
339    let returning_select = match &ior.returning {
340        Some((returning_span, returning_exprs)) => {
341            let columns = type_select_exprs(typer, returning_exprs, true)
342                .into_iter()
343                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
344                .collect();
345            Some(SelectType {
346                columns,
347                select_span: returning_span.join_span(returning_exprs),
348            })
349        }
350        None => None,
351    };
352
353    core::mem::drop(guard);
354
355    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
356        if ior
357            .flags
358            .iter()
359            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
360            || ior.on_duplicate_key_update.is_some()
361        {
362            AutoIncrementId::Optional
363        } else {
364            AutoIncrementId::Yes
365        }
366    } else {
367        AutoIncrementId::No
368    };
369
370    (auto_increment_id, returning_select)
371}