sql_type/
type_insert_replace.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use sql_parse::{
15    issue_todo, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType, Spanned,
16};
17
18use crate::{
19    type_expression::{type_expression, ExpressionFlags},
20    type_select::{type_select, type_select_exprs, SelectType},
21    typer::{typer_stack, unqualified_name, ReferenceType, Typer},
22    BaseType, SelectTypeColumn, Type,
23};
24
/// Whether an `INSERT`/`REPLACE` statement produces an auto increment id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutoIncrementId {
    /// The target table has an auto increment column and the statement is an
    /// `INSERT` with neither the `IGNORE` flag nor an
    /// `ON DUPLICATE KEY UPDATE` clause.
    Yes,
    /// The target table has no auto increment column (or is unknown), or the
    /// statement is not an `INSERT`.
    No,
    /// The target table has an auto increment column and the statement is an
    /// `INSERT`, but it carries the `IGNORE` flag or an
    /// `ON DUPLICATE KEY UPDATE` clause.
    Optional,
}
32
33pub(crate) fn type_insert_replace<'a>(
34    typer: &mut Typer<'a, '_>,
35    ior: &InsertReplace<'a>,
36) -> (AutoIncrementId, Option<SelectType<'a>>) {
37    let table = unqualified_name(typer.issues, &ior.table);
38    let columns = &ior.columns;
39
40    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
41        if schema.view {
42            typer.err("Inserts into views not yet implemented", table);
43        }
44        let mut col_types = Vec::new();
45
46        for col in columns {
47            if let Some(schema_col) = schema.get_column(col.value) {
48                col_types.push((schema_col.type_.clone(), col.span()));
49            } else {
50                typer.err("No such column in schema", col);
51            }
52        }
53        (
54            Some(col_types),
55            schema.columns.iter().any(|c| c.auto_increment),
56        )
57    } else {
58        typer.err("Unknown table", table);
59        (None, false)
60    };
61
62    if let Some(values) = &ior.values {
63        for row in &values.1 {
64            for (j, e) in row.iter().enumerate() {
65                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
66                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
67                    if typer.matched_type(&t, et).is_none() {
68                        typer
69                            .err(format!("Got type {}", t.t), e)
70                            .frag(format!("Expected {}", et.t), ets);
71                    } else if let Type::Args(_, args) = &t.t {
72                        for (idx, arg_type, _) in args.iter() {
73                            typer.constrain_arg(*idx, arg_type, et);
74                        }
75                    }
76                } else {
77                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
78                }
79            }
80        }
81    }
82
83    if let Some(select) = &ior.select {
84        let select = type_select(typer, select, true);
85        if let Some(s) = s {
86            for i in 0..usize::max(s.len(), select.columns.len()) {
87                match (s.get(i), select.columns.get(i)) {
88                    (Some((et, ets)), Some(t)) => {
89                        if typer.matched_type(&t.type_, et).is_none() {
90                            typer
91                                .err(format!("Got type {}", t.type_.t), &t.span)
92                                .frag(format!("Expected {}", et.t), ets);
93                        }
94                    }
95                    (None, Some(t)) => {
96                        typer.err("Column in select not in insert", &t.span);
97                    }
98                    (Some((_, ets)), None) => {
99                        typer.err("Missing column in select", ets);
100                    }
101                    (None, None) => {
102                        panic!("ICE")
103                    }
104                }
105            }
106        }
107    }
108
109    let mut guard = typer_stack(
110        typer,
111        |t| core::mem::take(&mut t.reference_types),
112        |t, v| t.reference_types = v,
113    );
114    let typer = &mut guard.typer;
115
116    if let Some(s) = typer.schemas.schemas.get(table.value) {
117        let mut columns = Vec::new();
118        for c in &s.columns {
119            columns.push((c.identifier.clone(), c.type_.clone()));
120        }
121        for v in &typer.reference_types {
122            if v.name == Some(table.clone()) {
123                typer
124                    .issues
125                    .err("Duplicate definitions", table)
126                    .frag("Already defined here", &v.span);
127            }
128        }
129        typer.reference_types.push(ReferenceType {
130            name: Some(table.clone()),
131            span: table.span(),
132            columns,
133        });
134    }
135
136    if let Some(set) = &ior.set {
137        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
138            let mut cnt = 0;
139            let mut t = None;
140            for r in &typer.reference_types {
141                for c in &r.columns {
142                    if c.0 == *column {
143                        cnt += 1;
144                        t = Some(c.clone());
145                    }
146                }
147            }
148            if cnt > 1 {
149                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
150                let mut issue = typer.issues.err("Ambiguous reference", column);
151                for r in &typer.reference_types {
152                    for c in &r.columns {
153                        if c.0 == *column {
154                            issue.frag("Defined here", &r.span);
155                        }
156                    }
157                }
158            } else if let Some(t) = t {
159                let value_type =
160                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
161                if typer.matched_type(&value_type, &t.1).is_none() {
162                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
163                } else if let Type::Args(_, args) = &value_type.t {
164                    for (idx, arg_type, _) in args.iter() {
165                        typer.constrain_arg(*idx, arg_type, &t.1);
166                    }
167                }
168            } else {
169                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
170                typer.err("Unknown identifier", column);
171            }
172        }
173    }
174
175    if let Some(up) = &ior.on_duplicate_key_update {
176        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
177            let mut cnt = 0;
178            let mut t = None;
179            for r in &typer.reference_types {
180                for c in &r.columns {
181                    if c.0 == *column {
182                        cnt += 1;
183                        t = Some(c.clone());
184                    }
185                }
186            }
187            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
188            if cnt > 1 {
189                type_expression(typer, value, flags, BaseType::Any);
190                let mut issue = typer.issues.err("Ambiguous reference", column);
191                for r in &typer.reference_types {
192                    for c in &r.columns {
193                        if c.0 == *column {
194                            issue.frag("Defined here", &r.span);
195                        }
196                    }
197                }
198            } else if let Some(t) = t {
199                let value_type = type_expression(typer, value, flags, t.1.base());
200                if typer.matched_type(&value_type, &t.1).is_none() {
201                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
202                } else if let Type::Args(_, args) = &value_type.t {
203                    for (idx, arg_type, _) in args.iter() {
204                        typer.constrain_arg(*idx, arg_type, &t.1);
205                    }
206                }
207            } else {
208                type_expression(typer, value, flags, BaseType::Any);
209                typer.err("Unknown identifier", column);
210            }
211        }
212    }
213
214    if let Some(on_conflict) = &ior.on_conflict {
215        match &on_conflict.target {
216            sql_parse::OnConflictTarget::Column { name } => {
217                let mut t = None;
218                for r in &typer.reference_types {
219                    for c in &r.columns {
220                        if c.0 == *name {
221                            t = Some(c.clone());
222                        }
223                    }
224                }
225                if t.is_none() {
226                    typer.err("Unknown identifier", name);
227                }
228                //TODO check if there is a unique constraint on column
229            }
230            sql_parse::OnConflictTarget::OnConstraint {
231                on_constraint_span, ..
232            } => {
233                issue_todo!(typer.issues, on_constraint_span);
234            }
235            sql_parse::OnConflictTarget::None => (),
236        }
237
238        match &on_conflict.action {
239            sql_parse::OnConflictAction::DoNothing(_) => (),
240            sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
241                for (key, value) in sets {
242                    let mut cnt = 0;
243                    let mut t = None;
244                    for r in &typer.reference_types {
245                        for c in &r.columns {
246                            if c.0 == *key {
247                                cnt += 1;
248                                t = Some(c.clone());
249                            }
250                        }
251                    }
252                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
253                    if cnt > 1 {
254                        type_expression(typer, value, flags, BaseType::Any);
255                        let mut issue = typer.issues.err("Ambiguous reference", key);
256                        for r in &typer.reference_types {
257                            for c in &r.columns {
258                                if c.0 == *key {
259                                    issue.frag("Defined here", &r.span);
260                                }
261                            }
262                        }
263                    } else if let Some(t) = t {
264                        let value_type = type_expression(typer, value, flags, t.1.base());
265                        if typer.matched_type(&value_type, &t.1).is_none() {
266                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
267                        } else if let Type::Args(_, args) = &value_type.t {
268                            for (idx, arg_type, _) in args.iter() {
269                                typer.constrain_arg(*idx, arg_type, &t.1);
270                            }
271                        }
272                    } else {
273                        type_expression(typer, value, flags, BaseType::Any);
274                        typer.err("Unknown identifier", key);
275                    }
276                }
277                if let Some((_, where_)) = where_ {
278                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
279                }
280            }
281        }
282    }
283
284    let returning_select = match &ior.returning {
285        Some((returning_span, returning_exprs)) => {
286            let columns = type_select_exprs(typer, returning_exprs, true)
287                .into_iter()
288                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
289                .collect();
290            Some(SelectType {
291                columns,
292                select_span: returning_span.join_span(returning_exprs),
293            })
294        }
295        None => None,
296    };
297
298    core::mem::drop(guard);
299
300    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
301        if ior
302            .flags
303            .iter()
304            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
305            || ior.on_duplicate_key_update.is_some()
306        {
307            AutoIncrementId::Optional
308        } else {
309            AutoIncrementId::Yes
310        }
311    } else {
312        AutoIncrementId::No
313    };
314
315    (auto_increment_id, returning_select)
316}