sql_type/
type_insert_replace.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use sql_parse::{
15    InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType, OptSpanned, Spanned,
16    issue_todo,
17};
18
19use crate::{
20    BaseType, SelectTypeColumn, Type,
21    type_expression::{ExpressionFlags, type_expression},
22    type_select::{SelectType, type_select, type_select_exprs},
23    typer::{ReferenceType, Typer, typer_stack, unqualified_name},
24};
25
/// Whether executing an insert statement yields an auto-increment id.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AutoIncrementId {
    /// The statement is expected to yield an auto-increment id.
    Yes,
    /// The statement yields no auto-increment id.
    No,
    /// The statement may or may not yield an id (for instance when the row can be
    /// skipped, or an existing row updated instead of a new one inserted).
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35    typer: &mut Typer<'a, '_>,
36    ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38    let table = unqualified_name(typer.issues, &ior.table);
39    let columns = &ior.columns;
40
41    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42        if schema.view {
43            typer.err("Inserts into views not yet implemented", table);
44        }
45        let mut col_types = Vec::new();
46
47        for col in columns {
48            if let Some(schema_col) = schema.get_column(col.value) {
49                col_types.push((schema_col.type_.clone(), col.span()));
50            } else {
51                typer.err("No such column in schema", col);
52            }
53        }
54
55        if let Some(set) = &ior.set {
56            for col in &schema.columns {
57                if col.auto_increment
58                    || col.default
59                    || !col.type_.not_null
60                    || col.as_.is_some()
61                    || col.generated
62                    || set.pairs.iter().any(|v| v.column == col.identifier)
63                {
64                    continue;
65                }
66                typer.err(
67                    format!(
68                        "No value for column {} provided, but it has no default value",
69                        &col.identifier
70                    ),
71                    set,
72                );
73            }
74        } else {
75            for col in &schema.columns {
76                if col.auto_increment
77                    || col.default
78                    || !col.type_.not_null
79                    || col.as_.is_some()
80                    || col.generated
81                    || columns.contains(&col.identifier)
82                {
83                    continue;
84                }
85                typer.err(
86                    format!(
87                        "No value for column {} provided, but it has no default value",
88                        &col.identifier
89                    ),
90                    &columns.opt_span().unwrap(),
91                );
92            }
93        }
94
95        (
96            Some(col_types),
97            schema.columns.iter().any(|c| c.auto_increment),
98        )
99    } else {
100        typer.err("Unknown table", table);
101        (None, false)
102    };
103
104    if let Some(values) = &ior.values {
105        for row in &values.1 {
106            for (j, e) in row.iter().enumerate() {
107                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109                    if typer.matched_type(&t, et).is_none() {
110                        typer
111                            .err(format!("Got type {}", t.t), e)
112                            .frag(format!("Expected {}", et.t), ets);
113                    } else if let Type::Args(_, args) = &t.t {
114                        for (idx, arg_type, _) in args.iter() {
115                            typer.constrain_arg(*idx, arg_type, et);
116                        }
117                    }
118                } else {
119                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120                }
121            }
122            if let Some(s) = &s
123                && s.len() != row.len() {
124                    typer
125                        .err(
126                            format!("Got {} columns", row.len()),
127                            &row.opt_span().unwrap(),
128                        )
129                        .frag(
130                            format!("Expected {}", columns.len()),
131                            &columns.opt_span().unwrap(),
132                        );
133                }
134        }
135    }
136
137    if let Some(select) = &ior.select {
138        let select = type_select(typer, select, true);
139        if let Some(s) = s {
140            for i in 0..usize::max(s.len(), select.columns.len()) {
141                match (s.get(i), select.columns.get(i)) {
142                    (Some((et, ets)), Some(t)) => {
143                        if typer.matched_type(&t.type_, et).is_none() {
144                            typer
145                                .err(format!("Got type {}", t.type_.t), &t.span)
146                                .frag(format!("Expected {}", et.t), ets);
147                        }
148                    }
149                    (None, Some(t)) => {
150                        typer.err("Column in select not in insert", &t.span);
151                    }
152                    (Some((_, ets)), None) => {
153                        typer.err("Missing column in select", ets);
154                    }
155                    (None, None) => {
156                        panic!("ICE")
157                    }
158                }
159            }
160        }
161    }
162
163    let mut guard = typer_stack(
164        typer,
165        |t| core::mem::take(&mut t.reference_types),
166        |t, v| t.reference_types = v,
167    );
168    let typer = &mut guard.typer;
169
170    if let Some(s) = typer.schemas.schemas.get(table.value) {
171        let mut columns = Vec::new();
172        for c in &s.columns {
173            columns.push((c.identifier.clone(), c.type_.clone()));
174        }
175        for v in &typer.reference_types {
176            if v.name == Some(table.clone()) {
177                typer
178                    .issues
179                    .err("Duplicate definitions", table)
180                    .frag("Already defined here", &v.span);
181            }
182        }
183        typer.reference_types.push(ReferenceType {
184            name: Some(table.clone()),
185            span: table.span(),
186            columns,
187        });
188    }
189
190    if let Some(set) = &ior.set {
191        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
192            let mut cnt = 0;
193            let mut t = None;
194            for r in &typer.reference_types {
195                for c in &r.columns {
196                    if c.0 == *column {
197                        cnt += 1;
198                        t = Some(c.clone());
199                    }
200                }
201            }
202            if cnt > 1 {
203                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
204                let mut issue = typer.issues.err("Ambiguous reference", column);
205                for r in &typer.reference_types {
206                    for c in &r.columns {
207                        if c.0 == *column {
208                            issue.frag("Defined here", &r.span);
209                        }
210                    }
211                }
212            } else if let Some(t) = t {
213                let value_type =
214                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
215                if typer.matched_type(&value_type, &t.1).is_none() {
216                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
217                } else if let Type::Args(_, args) = &value_type.t {
218                    for (idx, arg_type, _) in args.iter() {
219                        typer.constrain_arg(*idx, arg_type, &t.1);
220                    }
221                }
222            } else {
223                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
224                typer.err("Unknown identifier", column);
225            }
226        }
227    }
228
229    if let Some(up) = &ior.on_duplicate_key_update {
230        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
231            let mut cnt = 0;
232            let mut t = None;
233            for r in &typer.reference_types {
234                for c in &r.columns {
235                    if c.0 == *column {
236                        cnt += 1;
237                        t = Some(c.clone());
238                    }
239                }
240            }
241            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
242            if cnt > 1 {
243                type_expression(typer, value, flags, BaseType::Any);
244                let mut issue = typer.issues.err("Ambiguous reference", column);
245                for r in &typer.reference_types {
246                    for c in &r.columns {
247                        if c.0 == *column {
248                            issue.frag("Defined here", &r.span);
249                        }
250                    }
251                }
252            } else if let Some(t) = t {
253                let value_type = type_expression(typer, value, flags, t.1.base());
254                if typer.matched_type(&value_type, &t.1).is_none() {
255                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
256                } else if let Type::Args(_, args) = &value_type.t {
257                    for (idx, arg_type, _) in args.iter() {
258                        typer.constrain_arg(*idx, arg_type, &t.1);
259                    }
260                }
261            } else {
262                type_expression(typer, value, flags, BaseType::Any);
263                typer.err("Unknown identifier", column);
264            }
265        }
266    }
267
268    if let Some(on_conflict) = &ior.on_conflict {
269        match &on_conflict.target {
270            sql_parse::OnConflictTarget::Column { name } => {
271                let mut t = None;
272                for r in &typer.reference_types {
273                    for c in &r.columns {
274                        if c.0 == *name {
275                            t = Some(c.clone());
276                        }
277                    }
278                }
279                if t.is_none() {
280                    typer.err("Unknown identifier", name);
281                }
282                //TODO check if there is a unique constraint on column
283            }
284            sql_parse::OnConflictTarget::OnConstraint {
285                on_constraint_span, ..
286            } => {
287                issue_todo!(typer.issues, on_constraint_span);
288            }
289            sql_parse::OnConflictTarget::None => (),
290        }
291
292        match &on_conflict.action {
293            sql_parse::OnConflictAction::DoNothing(_) => (),
294            sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
295                for (key, value) in sets {
296                    let mut cnt = 0;
297                    let mut t = None;
298                    for r in &typer.reference_types {
299                        for c in &r.columns {
300                            if c.0 == *key {
301                                cnt += 1;
302                                t = Some(c.clone());
303                            }
304                        }
305                    }
306                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
307                    if cnt > 1 {
308                        type_expression(typer, value, flags, BaseType::Any);
309                        let mut issue = typer.issues.err("Ambiguous reference", key);
310                        for r in &typer.reference_types {
311                            for c in &r.columns {
312                                if c.0 == *key {
313                                    issue.frag("Defined here", &r.span);
314                                }
315                            }
316                        }
317                    } else if let Some(t) = t {
318                        let value_type = type_expression(typer, value, flags, t.1.base());
319                        if typer.matched_type(&value_type, &t.1).is_none() {
320                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
321                        } else if let Type::Args(_, args) = &value_type.t {
322                            for (idx, arg_type, _) in args.iter() {
323                                typer.constrain_arg(*idx, arg_type, &t.1);
324                            }
325                        }
326                    } else {
327                        type_expression(typer, value, flags, BaseType::Any);
328                        typer.err("Unknown identifier", key);
329                    }
330                }
331                if let Some((_, where_)) = where_ {
332                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
333                }
334            }
335        }
336    }
337
338    let returning_select = match &ior.returning {
339        Some((returning_span, returning_exprs)) => {
340            let columns = type_select_exprs(typer, returning_exprs, true)
341                .into_iter()
342                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
343                .collect();
344            Some(SelectType {
345                columns,
346                select_span: returning_span.join_span(returning_exprs),
347            })
348        }
349        None => None,
350    };
351
352    core::mem::drop(guard);
353
354    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
355        if ior
356            .flags
357            .iter()
358            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
359            || ior.on_duplicate_key_update.is_some()
360        {
361            AutoIncrementId::Optional
362        } else {
363            AutoIncrementId::Yes
364        }
365    } else {
366        AutoIncrementId::No
367    };
368
369    (auto_increment_id, returning_select)
370}