sql_type/
type_insert_replace.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use sql_parse::{
15    issue_todo, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16    OptSpanned, Spanned,
17};
18
19use crate::{
20    type_expression::{type_expression, ExpressionFlags},
21    type_select::{type_select, type_select_exprs, SelectType},
22    typer::{typer_stack, unqualified_name, ReferenceType, Typer},
23    BaseType, SelectTypeColumn, Type,
24};
25
/// Does the insert yield an auto increment id
///
/// This is the first element of the pair returned by `type_insert_replace`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AutoIncrementId {
    /// The statement is a plain `INSERT` into a table that has an auto increment
    /// column, so an id is always produced.
    Yes,
    /// The statement is not an `INSERT`, or the table has no auto increment column
    /// (or is unknown), so no id is produced.
    No,
    /// The statement is an `INSERT` into a table with an auto increment column, but
    /// it carries an `IGNORE` flag or an `ON DUPLICATE KEY UPDATE` clause, so a
    /// new row (and hence an id) may or may not be produced.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35    typer: &mut Typer<'a, '_>,
36    ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38    let table = unqualified_name(typer.issues, &ior.table);
39    let columns = &ior.columns;
40
41    let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42        if schema.view {
43            typer.err("Inserts into views not yet implemented", table);
44        }
45        let mut col_types = Vec::new();
46
47        for col in columns {
48            if let Some(schema_col) = schema.get_column(col.value) {
49                col_types.push((schema_col.type_.clone(), col.span()));
50            } else {
51                typer.err("No such column in schema", col);
52            }
53        }
54
55        if let Some(set) = &ior.set {
56            for col in &schema.columns {
57                if col.auto_increment
58                    || col.default
59                    || !col.type_.not_null
60                    || col.as_.is_some()
61                    || set.pairs.iter().any(|v| v.column==col.identifier)
62                {
63                    continue;
64                }
65                typer.err(
66                    format!(
67                        "No value for column {} provided, but it has no default value",
68                        &col.identifier
69                    ),
70                    set
71                );
72            }
73        } else {
74            for col in &schema.columns {
75                if col.auto_increment
76                    || col.default
77                    || !col.type_.not_null
78                    || col.as_.is_some()
79                    || columns.contains(&col.identifier)
80                {
81                    continue;
82                }
83                typer.err(
84                    format!(
85                        "No value for column {} provided, but it has no default value",
86                        &col.identifier
87                    ),
88                    &columns.opt_span().unwrap()
89                );
90            }
91        }
92
93        (
94            Some(col_types),
95            schema.columns.iter().any(|c| c.auto_increment),
96        )
97    } else {
98        typer.err("Unknown table", table);
99        (None, false)
100    };
101
102    if let Some(values) = &ior.values {
103        for row in &values.1 {
104            for (j, e) in row.iter().enumerate() {
105                if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
106                    let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
107                    if typer.matched_type(&t, et).is_none() {
108                        typer
109                            .err(format!("Got type {}", t.t), e)
110                            .frag(format!("Expected {}", et.t), ets);
111                    } else if let Type::Args(_, args) = &t.t {
112                        for (idx, arg_type, _) in args.iter() {
113                            typer.constrain_arg(*idx, arg_type, et);
114                        }
115                    }
116                } else {
117                    type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
118                }
119            }
120            if let Some(s) = &s {
121                if s.len() != row.len() {
122                    typer
123                        .err(
124                            format!("Got {} columns", row.len()),
125                            &row.opt_span().unwrap(),
126                        )
127                        .frag(
128                            format!("Expected {}", columns.len()),
129                            &columns.opt_span().unwrap(),
130                        );
131                }
132            }
133        }
134    }
135
136    if let Some(select) = &ior.select {
137        let select = type_select(typer, select, true);
138        if let Some(s) = s {
139            for i in 0..usize::max(s.len(), select.columns.len()) {
140                match (s.get(i), select.columns.get(i)) {
141                    (Some((et, ets)), Some(t)) => {
142                        if typer.matched_type(&t.type_, et).is_none() {
143                            typer
144                                .err(format!("Got type {}", t.type_.t), &t.span)
145                                .frag(format!("Expected {}", et.t), ets);
146                        }
147                    }
148                    (None, Some(t)) => {
149                        typer.err("Column in select not in insert", &t.span);
150                    }
151                    (Some((_, ets)), None) => {
152                        typer.err("Missing column in select", ets);
153                    }
154                    (None, None) => {
155                        panic!("ICE")
156                    }
157                }
158            }
159        }
160    }
161
162    let mut guard = typer_stack(
163        typer,
164        |t| core::mem::take(&mut t.reference_types),
165        |t, v| t.reference_types = v,
166    );
167    let typer = &mut guard.typer;
168
169    if let Some(s) = typer.schemas.schemas.get(table.value) {
170        let mut columns = Vec::new();
171        for c in &s.columns {
172            columns.push((c.identifier.clone(), c.type_.clone()));
173        }
174        for v in &typer.reference_types {
175            if v.name == Some(table.clone()) {
176                typer
177                    .issues
178                    .err("Duplicate definitions", table)
179                    .frag("Already defined here", &v.span);
180            }
181        }
182        typer.reference_types.push(ReferenceType {
183            name: Some(table.clone()),
184            span: table.span(),
185            columns,
186        });
187    }
188
189    if let Some(set) = &ior.set {
190        for InsertReplaceSetPair { column, value, .. } in &set.pairs {
191            let mut cnt = 0;
192            let mut t = None;
193            for r in &typer.reference_types {
194                for c in &r.columns {
195                    if c.0 == *column {
196                        cnt += 1;
197                        t = Some(c.clone());
198                    }
199                }
200            }
201            if cnt > 1 {
202                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
203                let mut issue = typer.issues.err("Ambiguous reference", column);
204                for r in &typer.reference_types {
205                    for c in &r.columns {
206                        if c.0 == *column {
207                            issue.frag("Defined here", &r.span);
208                        }
209                    }
210                }
211            } else if let Some(t) = t {
212                let value_type =
213                    type_expression(typer, value, ExpressionFlags::default(), t.1.base());
214                if typer.matched_type(&value_type, &t.1).is_none() {
215                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
216                } else if let Type::Args(_, args) = &value_type.t {
217                    for (idx, arg_type, _) in args.iter() {
218                        typer.constrain_arg(*idx, arg_type, &t.1);
219                    }
220                }
221            } else {
222                type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
223                typer.err("Unknown identifier", column);
224            }
225        }
226    }
227
228    if let Some(up) = &ior.on_duplicate_key_update {
229        for InsertReplaceSetPair { value, column, .. } in &up.pairs {
230            let mut cnt = 0;
231            let mut t = None;
232            for r in &typer.reference_types {
233                for c in &r.columns {
234                    if c.0 == *column {
235                        cnt += 1;
236                        t = Some(c.clone());
237                    }
238                }
239            }
240            let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
241            if cnt > 1 {
242                type_expression(typer, value, flags, BaseType::Any);
243                let mut issue = typer.issues.err("Ambiguous reference", column);
244                for r in &typer.reference_types {
245                    for c in &r.columns {
246                        if c.0 == *column {
247                            issue.frag("Defined here", &r.span);
248                        }
249                    }
250                }
251            } else if let Some(t) = t {
252                let value_type = type_expression(typer, value, flags, t.1.base());
253                if typer.matched_type(&value_type, &t.1).is_none() {
254                    typer.err(format!("Got type {} expected {}", value_type, t.1), value);
255                } else if let Type::Args(_, args) = &value_type.t {
256                    for (idx, arg_type, _) in args.iter() {
257                        typer.constrain_arg(*idx, arg_type, &t.1);
258                    }
259                }
260            } else {
261                type_expression(typer, value, flags, BaseType::Any);
262                typer.err("Unknown identifier", column);
263            }
264        }
265    }
266
267    if let Some(on_conflict) = &ior.on_conflict {
268        match &on_conflict.target {
269            sql_parse::OnConflictTarget::Column { name } => {
270                let mut t = None;
271                for r in &typer.reference_types {
272                    for c in &r.columns {
273                        if c.0 == *name {
274                            t = Some(c.clone());
275                        }
276                    }
277                }
278                if t.is_none() {
279                    typer.err("Unknown identifier", name);
280                }
281                //TODO check if there is a unique constraint on column
282            }
283            sql_parse::OnConflictTarget::OnConstraint {
284                on_constraint_span, ..
285            } => {
286                issue_todo!(typer.issues, on_constraint_span);
287            }
288            sql_parse::OnConflictTarget::None => (),
289        }
290
291        match &on_conflict.action {
292            sql_parse::OnConflictAction::DoNothing(_) => (),
293            sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
294                for (key, value) in sets {
295                    let mut cnt = 0;
296                    let mut t = None;
297                    for r in &typer.reference_types {
298                        for c in &r.columns {
299                            if c.0 == *key {
300                                cnt += 1;
301                                t = Some(c.clone());
302                            }
303                        }
304                    }
305                    let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
306                    if cnt > 1 {
307                        type_expression(typer, value, flags, BaseType::Any);
308                        let mut issue = typer.issues.err("Ambiguous reference", key);
309                        for r in &typer.reference_types {
310                            for c in &r.columns {
311                                if c.0 == *key {
312                                    issue.frag("Defined here", &r.span);
313                                }
314                            }
315                        }
316                    } else if let Some(t) = t {
317                        let value_type = type_expression(typer, value, flags, t.1.base());
318                        if typer.matched_type(&value_type, &t.1).is_none() {
319                            typer.err(format!("Got type {} expected {}", value_type, t.1), value);
320                        } else if let Type::Args(_, args) = &value_type.t {
321                            for (idx, arg_type, _) in args.iter() {
322                                typer.constrain_arg(*idx, arg_type, &t.1);
323                            }
324                        }
325                    } else {
326                        type_expression(typer, value, flags, BaseType::Any);
327                        typer.err("Unknown identifier", key);
328                    }
329                }
330                if let Some((_, where_)) = where_ {
331                    type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
332                }
333            }
334        }
335    }
336
337    let returning_select = match &ior.returning {
338        Some((returning_span, returning_exprs)) => {
339            let columns = type_select_exprs(typer, returning_exprs, true)
340                .into_iter()
341                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
342                .collect();
343            Some(SelectType {
344                columns,
345                select_span: returning_span.join_span(returning_exprs),
346            })
347        }
348        None => None,
349    };
350
351    core::mem::drop(guard);
352
353    let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
354        if ior
355            .flags
356            .iter()
357            .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
358            || ior.on_duplicate_key_update.is_some()
359        {
360            AutoIncrementId::Optional
361        } else {
362            AutoIncrementId::Yes
363        }
364    } else {
365        AutoIncrementId::No
366    };
367
368    (auto_increment_id, returning_select)
369}