1use alloc::{format, vec::Vec};
use sql_parse::{
    issue_todo, Expression, Identifier, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair,
    InsertReplaceType, OptSpanned, Spanned,
};
18
19use crate::{
20 type_expression::{type_expression, ExpressionFlags},
21 type_select::{type_select, type_select_exprs, SelectType},
22 typer::{typer_stack, unqualified_name, ReferenceType, Typer},
23 BaseType, SelectTypeColumn, Type,
24};
25
/// Whether executing an `INSERT`/`REPLACE` statement yields an auto-increment id.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// A plain `INSERT` into a table with an auto-increment column: an id is
    /// always produced.
    Yes,
    /// No id is produced. Either the table has no auto-increment column, or the
    /// statement is not a plain `INSERT` (for example `REPLACE`).
    No,
    /// An id may or may not be produced, because the insert itself may be
    /// skipped (for example `INSERT IGNORE` or an upsert clause).
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35 typer: &mut Typer<'a, '_>,
36 ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38 let table = unqualified_name(typer.issues, &ior.table);
39 let columns = &ior.columns;
40
41 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42 if schema.view {
43 typer.err("Inserts into views not yet implemented", table);
44 }
45 let mut col_types = Vec::new();
46
47 for col in columns {
48 if let Some(schema_col) = schema.get_column(col.value) {
49 col_types.push((schema_col.type_.clone(), col.span()));
50 } else {
51 typer.err("No such column in schema", col);
52 }
53 }
54
55 if let Some(set) = &ior.set {
56 for col in &schema.columns {
57 if col.auto_increment
58 || col.default
59 || !col.type_.not_null
60 || col.as_.is_some()
61 || col.generated
62 || set.pairs.iter().any(|v| v.column == col.identifier)
63 {
64 continue;
65 }
66 typer.err(
67 format!(
68 "No value for column {} provided, but it has no default value",
69 &col.identifier
70 ),
71 set,
72 );
73 }
74 } else {
75 for col in &schema.columns {
76 if col.auto_increment
77 || col.default
78 || !col.type_.not_null
79 || col.as_.is_some()
80 || col.generated
81 || columns.contains(&col.identifier)
82 {
83 continue;
84 }
85 typer.err(
86 format!(
87 "No value for column {} provided, but it has no default value",
88 &col.identifier
89 ),
90 &columns.opt_span().unwrap(),
91 );
92 }
93 }
94
95 (
96 Some(col_types),
97 schema.columns.iter().any(|c| c.auto_increment),
98 )
99 } else {
100 typer.err("Unknown table", table);
101 (None, false)
102 };
103
104 if let Some(values) = &ior.values {
105 for row in &values.1 {
106 for (j, e) in row.iter().enumerate() {
107 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
108 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
109 if typer.matched_type(&t, et).is_none() {
110 typer
111 .err(format!("Got type {}", t.t), e)
112 .frag(format!("Expected {}", et.t), ets);
113 } else if let Type::Args(_, args) = &t.t {
114 for (idx, arg_type, _) in args.iter() {
115 typer.constrain_arg(*idx, arg_type, et);
116 }
117 }
118 } else {
119 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
120 }
121 }
122 if let Some(s) = &s {
123 if s.len() != row.len() {
124 typer
125 .err(
126 format!("Got {} columns", row.len()),
127 &row.opt_span().unwrap(),
128 )
129 .frag(
130 format!("Expected {}", columns.len()),
131 &columns.opt_span().unwrap(),
132 );
133 }
134 }
135 }
136 }
137
138 if let Some(select) = &ior.select {
139 let select = type_select(typer, select, true);
140 if let Some(s) = s {
141 for i in 0..usize::max(s.len(), select.columns.len()) {
142 match (s.get(i), select.columns.get(i)) {
143 (Some((et, ets)), Some(t)) => {
144 if typer.matched_type(&t.type_, et).is_none() {
145 typer
146 .err(format!("Got type {}", t.type_.t), &t.span)
147 .frag(format!("Expected {}", et.t), ets);
148 }
149 }
150 (None, Some(t)) => {
151 typer.err("Column in select not in insert", &t.span);
152 }
153 (Some((_, ets)), None) => {
154 typer.err("Missing column in select", ets);
155 }
156 (None, None) => {
157 panic!("ICE")
158 }
159 }
160 }
161 }
162 }
163
164 let mut guard = typer_stack(
165 typer,
166 |t| core::mem::take(&mut t.reference_types),
167 |t, v| t.reference_types = v,
168 );
169 let typer = &mut guard.typer;
170
171 if let Some(s) = typer.schemas.schemas.get(table.value) {
172 let mut columns = Vec::new();
173 for c in &s.columns {
174 columns.push((c.identifier.clone(), c.type_.clone()));
175 }
176 for v in &typer.reference_types {
177 if v.name == Some(table.clone()) {
178 typer
179 .issues
180 .err("Duplicate definitions", table)
181 .frag("Already defined here", &v.span);
182 }
183 }
184 typer.reference_types.push(ReferenceType {
185 name: Some(table.clone()),
186 span: table.span(),
187 columns,
188 });
189 }
190
191 if let Some(set) = &ior.set {
192 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
193 let mut cnt = 0;
194 let mut t = None;
195 for r in &typer.reference_types {
196 for c in &r.columns {
197 if c.0 == *column {
198 cnt += 1;
199 t = Some(c.clone());
200 }
201 }
202 }
203 if cnt > 1 {
204 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
205 let mut issue = typer.issues.err("Ambiguous reference", column);
206 for r in &typer.reference_types {
207 for c in &r.columns {
208 if c.0 == *column {
209 issue.frag("Defined here", &r.span);
210 }
211 }
212 }
213 } else if let Some(t) = t {
214 let value_type =
215 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
216 if typer.matched_type(&value_type, &t.1).is_none() {
217 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
218 } else if let Type::Args(_, args) = &value_type.t {
219 for (idx, arg_type, _) in args.iter() {
220 typer.constrain_arg(*idx, arg_type, &t.1);
221 }
222 }
223 } else {
224 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
225 typer.err("Unknown identifier", column);
226 }
227 }
228 }
229
230 if let Some(up) = &ior.on_duplicate_key_update {
231 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
232 let mut cnt = 0;
233 let mut t = None;
234 for r in &typer.reference_types {
235 for c in &r.columns {
236 if c.0 == *column {
237 cnt += 1;
238 t = Some(c.clone());
239 }
240 }
241 }
242 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
243 if cnt > 1 {
244 type_expression(typer, value, flags, BaseType::Any);
245 let mut issue = typer.issues.err("Ambiguous reference", column);
246 for r in &typer.reference_types {
247 for c in &r.columns {
248 if c.0 == *column {
249 issue.frag("Defined here", &r.span);
250 }
251 }
252 }
253 } else if let Some(t) = t {
254 let value_type = type_expression(typer, value, flags, t.1.base());
255 if typer.matched_type(&value_type, &t.1).is_none() {
256 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
257 } else if let Type::Args(_, args) = &value_type.t {
258 for (idx, arg_type, _) in args.iter() {
259 typer.constrain_arg(*idx, arg_type, &t.1);
260 }
261 }
262 } else {
263 type_expression(typer, value, flags, BaseType::Any);
264 typer.err("Unknown identifier", column);
265 }
266 }
267 }
268
269 if let Some(on_conflict) = &ior.on_conflict {
270 match &on_conflict.target {
271 sql_parse::OnConflictTarget::Column { name } => {
272 let mut t = None;
273 for r in &typer.reference_types {
274 for c in &r.columns {
275 if c.0 == *name {
276 t = Some(c.clone());
277 }
278 }
279 }
280 if t.is_none() {
281 typer.err("Unknown identifier", name);
282 }
283 }
285 sql_parse::OnConflictTarget::OnConstraint {
286 on_constraint_span, ..
287 } => {
288 issue_todo!(typer.issues, on_constraint_span);
289 }
290 sql_parse::OnConflictTarget::None => (),
291 }
292
293 match &on_conflict.action {
294 sql_parse::OnConflictAction::DoNothing(_) => (),
295 sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
296 for (key, value) in sets {
297 let mut cnt = 0;
298 let mut t = None;
299 for r in &typer.reference_types {
300 for c in &r.columns {
301 if c.0 == *key {
302 cnt += 1;
303 t = Some(c.clone());
304 }
305 }
306 }
307 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
308 if cnt > 1 {
309 type_expression(typer, value, flags, BaseType::Any);
310 let mut issue = typer.issues.err("Ambiguous reference", key);
311 for r in &typer.reference_types {
312 for c in &r.columns {
313 if c.0 == *key {
314 issue.frag("Defined here", &r.span);
315 }
316 }
317 }
318 } else if let Some(t) = t {
319 let value_type = type_expression(typer, value, flags, t.1.base());
320 if typer.matched_type(&value_type, &t.1).is_none() {
321 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
322 } else if let Type::Args(_, args) = &value_type.t {
323 for (idx, arg_type, _) in args.iter() {
324 typer.constrain_arg(*idx, arg_type, &t.1);
325 }
326 }
327 } else {
328 type_expression(typer, value, flags, BaseType::Any);
329 typer.err("Unknown identifier", key);
330 }
331 }
332 if let Some((_, where_)) = where_ {
333 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
334 }
335 }
336 }
337 }
338
339 let returning_select = match &ior.returning {
340 Some((returning_span, returning_exprs)) => {
341 let columns = type_select_exprs(typer, returning_exprs, true)
342 .into_iter()
343 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
344 .collect();
345 Some(SelectType {
346 columns,
347 select_span: returning_span.join_span(returning_exprs),
348 })
349 }
350 None => None,
351 };
352
353 core::mem::drop(guard);
354
355 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
356 if ior
357 .flags
358 .iter()
359 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
360 || ior.on_duplicate_key_update.is_some()
361 {
362 AutoIncrementId::Optional
363 } else {
364 AutoIncrementId::Yes
365 }
366 } else {
367 AutoIncrementId::No
368 };
369
370 (auto_increment_id, returning_select)
371}