// sql_type/type_insert_replace.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

13use alloc::{format, vec::Vec};
14use sql_parse::{
15    Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16    OptSpanned, Spanned, issue_todo,
17};
18
19use crate::{
20    BaseType, SelectTypeColumn, Type,
21    type_expression::{ExpressionFlags, type_expression},
22    type_select::{SelectType, type_select, type_select_exprs},
23    typer::{ReferenceType, Typer, typer_stack, unqualified_name},
24};
25
/// Does the insert yield an auto increment id.
///
/// The enum carries no data, so it also derives `Copy` and `Hash`. Callers can
/// pass it by value and use it as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// An `INSERT` into a table with an auto increment column that cannot skip
    /// the insert, so an id is expected.
    Yes,
    /// No auto increment id is expected: the table has no auto increment
    /// column, or the statement is not an `INSERT` (e.g. `REPLACE`).
    No,
    /// An `INSERT` into a table with an auto increment column that may not
    /// insert a row (e.g. `INSERT IGNORE` or `ON DUPLICATE KEY UPDATE`).
    Optional,
}

34pub(crate) fn type_insert_replace<'a>(
35    typer: &mut Typer<'a, '_>,
36    ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38    let table = unqualified_name(typer.issues, &ior.table);
39    let columns = &ior.columns;
40
41    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42        if schema.view {
43            typer.err("Inserts into views not yet implemented", table);
44        }
45        let mut col_types = Vec::new();
46
47        for col in columns {
48            if let Some(schema_col) = schema.get_column(col.value) {
49                col_types.push((schema_col.type_.clone(), col.span()));
50            } else {
51                typer.err("No such column in schema", col);
52            }
53        }
54
55        if let Some(set) = &ior.set {
56            for col in &schema.columns {
57                if col.auto_increment
58                    || col.default
59                    || !col.type_.not_null
60                    || col.as_.is_some()
61                    || col.generated
62                    || set.pairs.iter().any(|v| v.column == col.identifier)
63                {
64                    continue;
65                }
66                typer.err(
67                    format!(
68                        "No value for column {} provided, but it has no default value",
69                        &col.identifier
70                    ),
71                    set,
72                );
73            }
74        } else {
75            for col in &schema.columns {
76                if col.auto_increment
77                    || col.default
78                    || !col.type_.not_null
79                    || col.as_.is_some()
80                    || col.generated
81                    || columns.contains(&col.identifier)
82                {
83                    continue;
84                }
85                typer.err(
86                    format!(
87                        "No value for column {} provided, but it has no default value",
88                        &col.identifier
89                    ),
90                    &columns.opt_span().unwrap(),
91                );
92            }
93        }
94
95        (
96            Some(col_types),
97            schema.columns.iter().any(|c| c.auto_increment),
98        )
99    } else {
100        typer.err("Unknown table", table);
101        (None, false)
102    };
103
104    if let Some(values) = &ior.values {
105        for row in &values.1 {
106            for (j, e) in row.iter().enumerate() {
107                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109                    if typer.matched_type(&t, et).is_none() {
110                        typer
111                            .err(format!("Got type {}", t.t), e)
112                            .frag(format!("Expected {}", et.t), ets);
113                    } else if let Type::Args(_, args) = &t.t {
114                        for (idx, arg_type, _) in args.iter() {
115                            typer.constrain_arg(*idx, arg_type, et);
116                        }
117                    }
118                } else {
119                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120                }
121            }
122            if let Some(s) = &s
123                && s.len() != row.len()
124            {
125                typer
126                    .err(
127                        format!("Got {} columns", row.len()),
128                        &row.opt_span().unwrap(),
129                    )
130                    .frag(
131                        format!("Expected {}", columns.len()),
132                        &columns.opt_span().unwrap(),
133                    );
134            }
135        }
136    }
137
138    if let Some(select) = &ior.select {
139        let select = type_select(typer, select, true);
140        if let Some(s) = s {
141            for i in 0..usize::max(s.len(), select.columns.len()) {
142                match (s.get(i), select.columns.get(i)) {
143                    (Some((et, ets)), Some(t)) => {
144                        if typer.matched_type(&t.type_, et).is_none() {
145                            typer
146                                .err(format!("Got type {}", t.type_.t), &t.span)
147                                .frag(format!("Expected {}", et.t), ets);
148                        }
149                    }
150                    (None, Some(t)) => {
151                        typer.err("Column in select not in insert", &t.span);
152                    }
153                    (Some((_, ets)), None) => {
154                        typer.err("Missing column in select", ets);
155                    }
156                    (None, None) => {
157                        panic!("ICE")
158                    }
159                }
160            }
161        }
162    }
163
164    let mut guard = typer_stack(
165        typer,
166        |t| core::mem::take(&mut t.reference_types),
167        |t, v| t.reference_types = v,
168    );
169    let typer = &mut guard.typer;
170
171    if let Some(s) = typer.schemas.schemas.get(table.value) {
172        let mut columns = Vec::new();
173        for c in &s.columns {
174            columns.push((c.identifier.clone(), c.type_.clone()));
175        }
176        for v in &typer.reference_types {
177            if v.name == Some(table.clone()) {
178                typer
179                    .issues
180                    .err("Duplicate definitions", table)
181                    .frag("Already defined here", &v.span);
182            }
183        }
184        typer.reference_types.push(ReferenceType {
185            name: Some(table.clone()),
186            span: table.span(),
187            columns,
188        });
189    }
190
191    if let Some(set) = &ior.set {
192        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
193            let mut cnt = 0;
194            let mut t = None;
195            for r in &typer.reference_types {
196                for c in &r.columns {
197                    if c.0 == *column {
198                        cnt += 1;
199                        t = Some(c.clone());
200                    }
201                }
202            }
203            if cnt > 1 {
204                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
205                let mut issue = typer.issues.err("Ambiguous reference", column);
206                for r in &typer.reference_types {
207                    for c in &r.columns {
208                        if c.0 == *column {
209                            issue.frag("Defined here", &r.span);
210                        }
211                    }
212                }
213            } else if let Some(t) = t {
214                let value_type =
215                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
216                if typer.matched_type(&value_type, &t.1).is_none() {
217                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
218                } else if let Type::Args(_, args) = &value_type.t {
219                    for (idx, arg_type, _) in args.iter() {
220                        typer.constrain_arg(*idx, arg_type, &t.1);
221                    }
222                }
223            } else {
224                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
225                typer.err("Unknown identifier", column);
226            }
227        }
228    }
229
230    if let Some(up) = &ior.on_duplicate_key_update {
231        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
232            let mut cnt = 0;
233            let mut t = None;
234            for r in &typer.reference_types {
235                for c in &r.columns {
236                    if c.0 == *column {
237                        cnt += 1;
238                        t = Some(c.clone());
239                    }
240                }
241            }
242            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
243            if cnt > 1 {
244                type_expression(typer, value, flags, BaseType::Any);
245                let mut issue = typer.issues.err("Ambiguous reference", column);
246                for r in &typer.reference_types {
247                    for c in &r.columns {
248                        if c.0 == *column {
249                            issue.frag("Defined here", &r.span);
250                        }
251                    }
252                }
253            } else if let Some(t) = t {
254                let value_type = type_expression(typer, value, flags, t.1.base());
255                if typer.matched_type(&value_type, &t.1).is_none() {
256                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
257                } else if let Type::Args(_, args) = &value_type.t {
258                    for (idx, arg_type, _) in args.iter() {
259                        typer.constrain_arg(*idx, arg_type, &t.1);
260                    }
261                }
262            } else {
263                type_expression(typer, value, flags, BaseType::Any);
264                typer.err("Unknown identifier", column);
265            }
266        }
267    }
268
269    if let Some(on_conflict) = &ior.on_conflict {
270        match &on_conflict.target {
271            sql_parse::OnConflictTarget::Columns { names } => {
272                for name in names {
273                    let mut t = None;
274                    for r in &typer.reference_types {
275                        for c in &r.columns {
276                            if c.0 == *name {
277                                t = Some(c.clone());
278                            }
279                        }
280                    }
281                    if t.is_none() {
282                        typer.err("Unknown identifier", name);
283                    }
284                }
285                //TODO check if there is a unique constraint on column
286            }
287            sql_parse::OnConflictTarget::OnConstraint {
288                on_constraint_span, ..
289            } => {
290                issue_todo!(typer.issues, on_constraint_span);
291            }
292            sql_parse::OnConflictTarget::None => (),
293        }
294
295        match &on_conflict.action {
296            sql_parse::OnConflictAction::DoNothing(_) => (),
297            sql_parse::OnConflictAction::DoUpdateSet {
298                sets,
299                where_,
300                do_update_set_span,
301            } => {
302                let mut excluded_columns = Vec::new();
303                if let Some(schema) = typer.schemas.schemas.get(table.value)
304                    && !schema.view
305                {
306                    for col in &ior.columns {
307                        if let Some(schema_col) = schema.get_column(col.value) {
308                            excluded_columns
309                                .push((schema_col.identifier.clone(), schema_col.type_.clone()));
310                        }
311                    }
312                }
313
314                typer.reference_types.push(ReferenceType {
315                    name: Some(Identifier::new("EXCLUDED", do_update_set_span.clone())),
316                    span: do_update_set_span.span(),
317                    columns: excluded_columns,
318                });
319
320                for (key, value) in sets {
321                    let mut cnt = 0;
322                    let mut t = None;
323                    for r in &typer.reference_types {
324                        if let Some(name) = &r.name
325                            && name.value == "EXCLUDED"
326                        {
327                            continue;
328                        }
329                        for c in &r.columns {
330                            if c.0 == *key {
331                                cnt += 1;
332                                t = Some(c.clone());
333                            }
334                        }
335                    }
336                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
337                    if cnt > 1 {
338                        type_expression(typer, value, flags, BaseType::Any);
339                        let mut issue = typer.issues.err("Ambiguous reference", key);
340                        for r in &typer.reference_types {
341                            for c in &r.columns {
342                                if c.0 == *key {
343                                    issue.frag("Defined here", &r.span);
344                                }
345                            }
346                        }
347                    } else if let Some(t) = t {
348                        let value_type = type_expression(typer, value, flags, t.1.base());
349                        if typer.matched_type(&value_type, &t.1).is_none() {
350                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
351                        } else if let Type::Args(_, args) = &value_type.t {
352                            for (idx, arg_type, _) in args.iter() {
353                                typer.constrain_arg(*idx, arg_type, &t.1);
354                            }
355                        }
356                    } else {
357                        type_expression(typer, value, flags, BaseType::Any);
358                        typer.err("Unknown identifier", key);
359                    }
360                }
361                if let Some((_, where_)) = where_ {
362                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
363                }
364            }
365        }
366    }
367
368    let returning_select = match &ior.returning {
369        Some((returning_span, returning_exprs)) => {
370            let columns = type_select_exprs(typer, returning_exprs, true)
371                .into_iter()
372                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
373                .collect();
374            Some(SelectType {
375                columns,
376                select_span: returning_span.join_span(returning_exprs),
377            })
378        }
379        None => None,
380    };
381
382    core::mem::drop(guard);
383
384    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
385        if ior
386            .flags
387            .iter()
388            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
389            || ior.on_duplicate_key_update.is_some()
390        {
391            AutoIncrementId::Optional
392        } else {
393            AutoIncrementId::Yes
394        }
395    } else {
396        AutoIncrementId::No
397    };
398
399    (auto_increment_id, returning_select)
400}