1use alloc::{format, vec::Vec};
14use sql_parse::{
15 Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16 OptSpanned, Spanned, issue_todo,
17};
18
19use crate::{
20 BaseType, SelectTypeColumn, Type,
21 type_expression::{ExpressionFlags, type_expression},
22 type_select::{SelectType, type_select, type_select_exprs},
23 typer::{ReferenceType, Typer, typer_stack, unqualified_name},
24};
25
/// Whether executing an insert statement produces an auto-increment id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AutoIncrementId {
    /// An auto-increment id is produced.
    Yes,
    /// No auto-increment id is produced.
    No,
    /// An auto-increment id may or may not be produced, e.g. because the row can be
    /// skipped by `IGNORE` or turned into an update by `ON DUPLICATE KEY UPDATE`.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35 typer: &mut Typer<'a, '_>,
36 ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38 let table = unqualified_name(typer.issues, &ior.table);
39 let columns = &ior.columns;
40
41 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42 if schema.view {
43 typer.err("Inserts into views not yet implemented", table);
44 }
45 let mut col_types = Vec::new();
46
47 for col in columns {
48 if let Some(schema_col) = schema.get_column(col.value) {
49 col_types.push((schema_col.type_.clone(), col.span()));
50 } else {
51 typer.err("No such column in schema", col);
52 }
53 }
54
55 if let Some(set) = &ior.set {
56 for col in &schema.columns {
57 if col.auto_increment
58 || col.default
59 || !col.type_.not_null
60 || col.as_.is_some()
61 || col.generated
62 || set.pairs.iter().any(|v| v.column == col.identifier)
63 {
64 continue;
65 }
66 typer.err(
67 format!(
68 "No value for column {} provided, but it has no default value",
69 &col.identifier
70 ),
71 set,
72 );
73 }
74 } else {
75 for col in &schema.columns {
76 if col.auto_increment
77 || col.default
78 || !col.type_.not_null
79 || col.as_.is_some()
80 || col.generated
81 || columns.contains(&col.identifier)
82 {
83 continue;
84 }
85 typer.err(
86 format!(
87 "No value for column {} provided, but it has no default value",
88 &col.identifier
89 ),
90 &columns.opt_span().unwrap(),
91 );
92 }
93 }
94
95 (
96 Some(col_types),
97 schema.columns.iter().any(|c| c.auto_increment),
98 )
99 } else {
100 typer.err("Unknown table", table);
101 (None, false)
102 };
103
104 if let Some(values) = &ior.values {
105 for row in &values.1 {
106 for (j, e) in row.iter().enumerate() {
107 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109 if typer.matched_type(&t, et).is_none() {
110 typer
111 .err(format!("Got type {}", t.t), e)
112 .frag(format!("Expected {}", et.t), ets);
113 } else if let Type::Args(_, args) = &t.t {
114 for (idx, arg_type, _) in args.iter() {
115 typer.constrain_arg(*idx, arg_type, et);
116 }
117 }
118 } else {
119 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120 }
121 }
122 if let Some(s) = &s
123 && s.len() != row.len()
124 {
125 typer
126 .err(
127 format!("Got {} columns", row.len()),
128 &row.opt_span().unwrap(),
129 )
130 .frag(
131 format!("Expected {}", columns.len()),
132 &columns.opt_span().unwrap(),
133 );
134 }
135 }
136 }
137
138 if let Some(select) = &ior.select {
139 let select = type_select(typer, select, true);
140 if let Some(s) = s {
141 for i in 0..usize::max(s.len(), select.columns.len()) {
142 match (s.get(i), select.columns.get(i)) {
143 (Some((et, ets)), Some(t)) => {
144 if typer.matched_type(&t.type_, et).is_none() {
145 typer
146 .err(format!("Got type {}", t.type_.t), &t.span)
147 .frag(format!("Expected {}", et.t), ets);
148 }
149 }
150 (None, Some(t)) => {
151 typer.err("Column in select not in insert", &t.span);
152 }
153 (Some((_, ets)), None) => {
154 typer.err("Missing column in select", ets);
155 }
156 (None, None) => {
157 panic!("ICE")
158 }
159 }
160 }
161 }
162 }
163
164 let mut guard = typer_stack(
165 typer,
166 |t| core::mem::take(&mut t.reference_types),
167 |t, v| t.reference_types = v,
168 );
169 let typer = &mut guard.typer;
170
171 if let Some(s) = typer.schemas.schemas.get(table.value) {
172 let mut columns = Vec::new();
173 for c in &s.columns {
174 columns.push((c.identifier.clone(), c.type_.clone()));
175 }
176 for v in &typer.reference_types {
177 if v.name == Some(table.clone()) {
178 typer
179 .issues
180 .err("Duplicate definitions", table)
181 .frag("Already defined here", &v.span);
182 }
183 }
184 typer.reference_types.push(ReferenceType {
185 name: Some(table.clone()),
186 span: table.span(),
187 columns,
188 });
189 }
190
191 if let Some(set) = &ior.set {
192 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
193 let mut cnt = 0;
194 let mut t = None;
195 for r in &typer.reference_types {
196 for c in &r.columns {
197 if c.0 == *column {
198 cnt += 1;
199 t = Some(c.clone());
200 }
201 }
202 }
203 if cnt > 1 {
204 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
205 let mut issue = typer.issues.err("Ambiguous reference", column);
206 for r in &typer.reference_types {
207 for c in &r.columns {
208 if c.0 == *column {
209 issue.frag("Defined here", &r.span);
210 }
211 }
212 }
213 } else if let Some(t) = t {
214 let value_type =
215 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
216 if typer.matched_type(&value_type, &t.1).is_none() {
217 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
218 } else if let Type::Args(_, args) = &value_type.t {
219 for (idx, arg_type, _) in args.iter() {
220 typer.constrain_arg(*idx, arg_type, &t.1);
221 }
222 }
223 } else {
224 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
225 typer.err("Unknown identifier", column);
226 }
227 }
228 }
229
230 if let Some(up) = &ior.on_duplicate_key_update {
231 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
232 let mut cnt = 0;
233 let mut t = None;
234 for r in &typer.reference_types {
235 for c in &r.columns {
236 if c.0 == *column {
237 cnt += 1;
238 t = Some(c.clone());
239 }
240 }
241 }
242 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
243 if cnt > 1 {
244 type_expression(typer, value, flags, BaseType::Any);
245 let mut issue = typer.issues.err("Ambiguous reference", column);
246 for r in &typer.reference_types {
247 for c in &r.columns {
248 if c.0 == *column {
249 issue.frag("Defined here", &r.span);
250 }
251 }
252 }
253 } else if let Some(t) = t {
254 let value_type = type_expression(typer, value, flags, t.1.base());
255 if typer.matched_type(&value_type, &t.1).is_none() {
256 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
257 } else if let Type::Args(_, args) = &value_type.t {
258 for (idx, arg_type, _) in args.iter() {
259 typer.constrain_arg(*idx, arg_type, &t.1);
260 }
261 }
262 } else {
263 type_expression(typer, value, flags, BaseType::Any);
264 typer.err("Unknown identifier", column);
265 }
266 }
267 }
268
269 if let Some(on_conflict) = &ior.on_conflict {
270 match &on_conflict.target {
271 sql_parse::OnConflictTarget::Columns { names } => {
272 for name in names {
273 let mut t = None;
274 for r in &typer.reference_types {
275 for c in &r.columns {
276 if c.0 == *name {
277 t = Some(c.clone());
278 }
279 }
280 }
281 if t.is_none() {
282 typer.err("Unknown identifier", name);
283 }
284 }
285 }
287 sql_parse::OnConflictTarget::OnConstraint {
288 on_constraint_span, ..
289 } => {
290 issue_todo!(typer.issues, on_constraint_span);
291 }
292 sql_parse::OnConflictTarget::None => (),
293 }
294
295 match &on_conflict.action {
296 sql_parse::OnConflictAction::DoNothing(_) => (),
297 sql_parse::OnConflictAction::DoUpdateSet {
298 sets,
299 where_,
300 do_update_set_span,
301 } => {
302 let mut excluded_columns = Vec::new();
303 if let Some(schema) = typer.schemas.schemas.get(table.value)
304 && !schema.view
305 {
306 for col in &ior.columns {
307 if let Some(schema_col) = schema.get_column(col.value) {
308 excluded_columns
309 .push((schema_col.identifier.clone(), schema_col.type_.clone()));
310 }
311 }
312 }
313
314 typer.reference_types.push(ReferenceType {
315 name: Some(Identifier::new("EXCLUDED", do_update_set_span.clone())),
316 span: do_update_set_span.span(),
317 columns: excluded_columns,
318 });
319
320 for (key, value) in sets {
321 let mut cnt = 0;
322 let mut t = None;
323 for r in &typer.reference_types {
324 if let Some(name) = &r.name
325 && name.value == "EXCLUDED"
326 {
327 continue;
328 }
329 for c in &r.columns {
330 if c.0 == *key {
331 cnt += 1;
332 t = Some(c.clone());
333 }
334 }
335 }
336 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
337 if cnt > 1 {
338 type_expression(typer, value, flags, BaseType::Any);
339 let mut issue = typer.issues.err("Ambiguous reference", key);
340 for r in &typer.reference_types {
341 for c in &r.columns {
342 if c.0 == *key {
343 issue.frag("Defined here", &r.span);
344 }
345 }
346 }
347 } else if let Some(t) = t {
348 let value_type = type_expression(typer, value, flags, t.1.base());
349 if typer.matched_type(&value_type, &t.1).is_none() {
350 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
351 } else if let Type::Args(_, args) = &value_type.t {
352 for (idx, arg_type, _) in args.iter() {
353 typer.constrain_arg(*idx, arg_type, &t.1);
354 }
355 }
356 } else {
357 type_expression(typer, value, flags, BaseType::Any);
358 typer.err("Unknown identifier", key);
359 }
360 }
361 if let Some((_, where_)) = where_ {
362 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
363 }
364 }
365 }
366 }
367
368 let returning_select = match &ior.returning {
369 Some((returning_span, returning_exprs)) => {
370 let columns = type_select_exprs(typer, returning_exprs, true)
371 .into_iter()
372 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
373 .collect();
374 Some(SelectType {
375 columns,
376 select_span: returning_span.join_span(returning_exprs),
377 })
378 }
379 None => None,
380 };
381
382 core::mem::drop(guard);
383
384 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
385 if ior
386 .flags
387 .iter()
388 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
389 || ior.on_duplicate_key_update.is_some()
390 {
391 AutoIncrementId::Optional
392 } else {
393 AutoIncrementId::Yes
394 }
395 } else {
396 AutoIncrementId::No
397 };
398
399 (auto_increment_id, returning_select)
400}