1use alloc::{format, vec::Vec};
14use sql_parse::{
15 InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType, OptSpanned, Spanned,
16 issue_todo,
17};
18
19use crate::{
20 BaseType, SelectTypeColumn, Type,
21 type_expression::{ExpressionFlags, type_expression},
22 type_select::{SelectType, type_select, type_select_exprs},
23 typer::{ReferenceType, Typer, typer_stack, unqualified_name},
24};
25
/// Whether executing an insert/replace statement yields an auto increment id,
/// as computed by `type_insert_replace`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// The table has an auto increment column and the statement is a plain
    /// `INSERT` without an `IGNORE` flag or `ON DUPLICATE KEY UPDATE` clause.
    Yes,
    /// The table has no auto increment column, or the statement is not an
    /// `INSERT`.
    No,
    /// The table has an auto increment column and the statement is an `INSERT`
    /// carrying an `IGNORE` flag or an `ON DUPLICATE KEY UPDATE` clause.
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35 typer: &mut Typer<'a, '_>,
36 ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38 let table = unqualified_name(typer.issues, &ior.table);
39 let columns = &ior.columns;
40
41 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42 if schema.view {
43 typer.err("Inserts into views not yet implemented", table);
44 }
45 let mut col_types = Vec::new();
46
47 for col in columns {
48 if let Some(schema_col) = schema.get_column(col.value) {
49 col_types.push((schema_col.type_.clone(), col.span()));
50 } else {
51 typer.err("No such column in schema", col);
52 }
53 }
54
55 if let Some(set) = &ior.set {
56 for col in &schema.columns {
57 if col.auto_increment
58 || col.default
59 || !col.type_.not_null
60 || col.as_.is_some()
61 || col.generated
62 || set.pairs.iter().any(|v| v.column == col.identifier)
63 {
64 continue;
65 }
66 typer.err(
67 format!(
68 "No value for column {} provided, but it has no default value",
69 &col.identifier
70 ),
71 set,
72 );
73 }
74 } else {
75 for col in &schema.columns {
76 if col.auto_increment
77 || col.default
78 || !col.type_.not_null
79 || col.as_.is_some()
80 || col.generated
81 || columns.contains(&col.identifier)
82 {
83 continue;
84 }
85 typer.err(
86 format!(
87 "No value for column {} provided, but it has no default value",
88 &col.identifier
89 ),
90 &columns.opt_span().unwrap(),
91 );
92 }
93 }
94
95 (
96 Some(col_types),
97 schema.columns.iter().any(|c| c.auto_increment),
98 )
99 } else {
100 typer.err("Unknown table", table);
101 (None, false)
102 };
103
104 if let Some(values) = &ior.values {
105 for row in &values.1 {
106 for (j, e) in row.iter().enumerate() {
107 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109 if typer.matched_type(&t, et).is_none() {
110 typer
111 .err(format!("Got type {}", t.t), e)
112 .frag(format!("Expected {}", et.t), ets);
113 } else if let Type::Args(_, args) = &t.t {
114 for (idx, arg_type, _) in args.iter() {
115 typer.constrain_arg(*idx, arg_type, et);
116 }
117 }
118 } else {
119 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120 }
121 }
122 if let Some(s) = &s
123 && s.len() != row.len() {
124 typer
125 .err(
126 format!("Got {} columns", row.len()),
127 &row.opt_span().unwrap(),
128 )
129 .frag(
130 format!("Expected {}", columns.len()),
131 &columns.opt_span().unwrap(),
132 );
133 }
134 }
135 }
136
137 if let Some(select) = &ior.select {
138 let select = type_select(typer, select, true);
139 if let Some(s) = s {
140 for i in 0..usize::max(s.len(), select.columns.len()) {
141 match (s.get(i), select.columns.get(i)) {
142 (Some((et, ets)), Some(t)) => {
143 if typer.matched_type(&t.type_, et).is_none() {
144 typer
145 .err(format!("Got type {}", t.type_.t), &t.span)
146 .frag(format!("Expected {}", et.t), ets);
147 }
148 }
149 (None, Some(t)) => {
150 typer.err("Column in select not in insert", &t.span);
151 }
152 (Some((_, ets)), None) => {
153 typer.err("Missing column in select", ets);
154 }
155 (None, None) => {
156 panic!("ICE")
157 }
158 }
159 }
160 }
161 }
162
163 let mut guard = typer_stack(
164 typer,
165 |t| core::mem::take(&mut t.reference_types),
166 |t, v| t.reference_types = v,
167 );
168 let typer = &mut guard.typer;
169
170 if let Some(s) = typer.schemas.schemas.get(table.value) {
171 let mut columns = Vec::new();
172 for c in &s.columns {
173 columns.push((c.identifier.clone(), c.type_.clone()));
174 }
175 for v in &typer.reference_types {
176 if v.name == Some(table.clone()) {
177 typer
178 .issues
179 .err("Duplicate definitions", table)
180 .frag("Already defined here", &v.span);
181 }
182 }
183 typer.reference_types.push(ReferenceType {
184 name: Some(table.clone()),
185 span: table.span(),
186 columns,
187 });
188 }
189
190 if let Some(set) = &ior.set {
191 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
192 let mut cnt = 0;
193 let mut t = None;
194 for r in &typer.reference_types {
195 for c in &r.columns {
196 if c.0 == *column {
197 cnt += 1;
198 t = Some(c.clone());
199 }
200 }
201 }
202 if cnt > 1 {
203 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
204 let mut issue = typer.issues.err("Ambiguous reference", column);
205 for r in &typer.reference_types {
206 for c in &r.columns {
207 if c.0 == *column {
208 issue.frag("Defined here", &r.span);
209 }
210 }
211 }
212 } else if let Some(t) = t {
213 let value_type =
214 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
215 if typer.matched_type(&value_type, &t.1).is_none() {
216 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
217 } else if let Type::Args(_, args) = &value_type.t {
218 for (idx, arg_type, _) in args.iter() {
219 typer.constrain_arg(*idx, arg_type, &t.1);
220 }
221 }
222 } else {
223 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
224 typer.err("Unknown identifier", column);
225 }
226 }
227 }
228
229 if let Some(up) = &ior.on_duplicate_key_update {
230 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
231 let mut cnt = 0;
232 let mut t = None;
233 for r in &typer.reference_types {
234 for c in &r.columns {
235 if c.0 == *column {
236 cnt += 1;
237 t = Some(c.clone());
238 }
239 }
240 }
241 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
242 if cnt > 1 {
243 type_expression(typer, value, flags, BaseType::Any);
244 let mut issue = typer.issues.err("Ambiguous reference", column);
245 for r in &typer.reference_types {
246 for c in &r.columns {
247 if c.0 == *column {
248 issue.frag("Defined here", &r.span);
249 }
250 }
251 }
252 } else if let Some(t) = t {
253 let value_type = type_expression(typer, value, flags, t.1.base());
254 if typer.matched_type(&value_type, &t.1).is_none() {
255 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
256 } else if let Type::Args(_, args) = &value_type.t {
257 for (idx, arg_type, _) in args.iter() {
258 typer.constrain_arg(*idx, arg_type, &t.1);
259 }
260 }
261 } else {
262 type_expression(typer, value, flags, BaseType::Any);
263 typer.err("Unknown identifier", column);
264 }
265 }
266 }
267
268 if let Some(on_conflict) = &ior.on_conflict {
269 match &on_conflict.target {
270 sql_parse::OnConflictTarget::Column { name } => {
271 let mut t = None;
272 for r in &typer.reference_types {
273 for c in &r.columns {
274 if c.0 == *name {
275 t = Some(c.clone());
276 }
277 }
278 }
279 if t.is_none() {
280 typer.err("Unknown identifier", name);
281 }
282 }
284 sql_parse::OnConflictTarget::OnConstraint {
285 on_constraint_span, ..
286 } => {
287 issue_todo!(typer.issues, on_constraint_span);
288 }
289 sql_parse::OnConflictTarget::None => (),
290 }
291
292 match &on_conflict.action {
293 sql_parse::OnConflictAction::DoNothing(_) => (),
294 sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
295 for (key, value) in sets {
296 let mut cnt = 0;
297 let mut t = None;
298 for r in &typer.reference_types {
299 for c in &r.columns {
300 if c.0 == *key {
301 cnt += 1;
302 t = Some(c.clone());
303 }
304 }
305 }
306 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
307 if cnt > 1 {
308 type_expression(typer, value, flags, BaseType::Any);
309 let mut issue = typer.issues.err("Ambiguous reference", key);
310 for r in &typer.reference_types {
311 for c in &r.columns {
312 if c.0 == *key {
313 issue.frag("Defined here", &r.span);
314 }
315 }
316 }
317 } else if let Some(t) = t {
318 let value_type = type_expression(typer, value, flags, t.1.base());
319 if typer.matched_type(&value_type, &t.1).is_none() {
320 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
321 } else if let Type::Args(_, args) = &value_type.t {
322 for (idx, arg_type, _) in args.iter() {
323 typer.constrain_arg(*idx, arg_type, &t.1);
324 }
325 }
326 } else {
327 type_expression(typer, value, flags, BaseType::Any);
328 typer.err("Unknown identifier", key);
329 }
330 }
331 if let Some((_, where_)) = where_ {
332 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
333 }
334 }
335 }
336 }
337
338 let returning_select = match &ior.returning {
339 Some((returning_span, returning_exprs)) => {
340 let columns = type_select_exprs(typer, returning_exprs, true)
341 .into_iter()
342 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
343 .collect();
344 Some(SelectType {
345 columns,
346 select_span: returning_span.join_span(returning_exprs),
347 })
348 }
349 None => None,
350 };
351
352 core::mem::drop(guard);
353
354 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
355 if ior
356 .flags
357 .iter()
358 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
359 || ior.on_duplicate_key_update.is_some()
360 {
361 AutoIncrementId::Optional
362 } else {
363 AutoIncrementId::Yes
364 }
365 } else {
366 AutoIncrementId::No
367 };
368
369 (auto_increment_id, returning_select)
370}