1use alloc::{format, vec::Vec};
14use sql_parse::{
15 issue_todo, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType,
16 OptSpanned, Spanned,
17};
18
19use crate::{
20 type_expression::{type_expression, ExpressionFlags},
21 type_select::{type_select, type_select_exprs, SelectType},
22 typer::{typer_stack, unqualified_name, ReferenceType, Typer},
23 BaseType, SelectTypeColumn, Type,
24};
25
/// Whether executing an `INSERT`/`REPLACE` statement produces an auto-increment id.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AutoIncrementId {
    /// The statement always produces an auto-increment id.
    Yes,
    /// The statement produces no auto-increment id.
    No,
    /// The statement may produce an auto-increment id, but it is not guaranteed
    /// (for example when the row might not actually be inserted).
    Optional,
}
33
34pub(crate) fn type_insert_replace<'a>(
35 typer: &mut Typer<'a, '_>,
36 ior: &InsertReplace<'a>,
37) -> (AutoIncrementId, Option<SelectType<'a>>) {
38 let table = unqualified_name(typer.issues, &ior.table);
39 let columns = &ior.columns;
40
41 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
42 if schema.view {
43 typer.err("Inserts into views not yet implemented", table);
44 }
45 let mut col_types = Vec::new();
46
47 for col in columns {
48 if let Some(schema_col) = schema.get_column(col.value) {
49 col_types.push((schema_col.type_.clone(), col.span()));
50 } else {
51 typer.err("No such column in schema", col);
52 }
53 }
54
55 if let Some(set) = &ior.set {
56 for col in &schema.columns {
57 if col.auto_increment
58 || col.default
59 || !col.type_.not_null
60 || col.as_.is_some()
61 || set.pairs.iter().any(|v| v.column==col.identifier)
62 {
63 continue;
64 }
65 typer.err(
66 format!(
67 "No value for column {} provided, but it has no default value",
68 &col.identifier
69 ),
70 set
71 );
72 }
73 } else {
74 for col in &schema.columns {
75 if col.auto_increment
76 || col.default
77 || !col.type_.not_null
78 || col.as_.is_some()
79 || columns.contains(&col.identifier)
80 {
81 continue;
82 }
83 typer.err(
84 format!(
85 "No value for column {} provided, but it has no default value",
86 &col.identifier
87 ),
88 &columns.opt_span().unwrap()
89 );
90 }
91 }
92
93 (
94 Some(col_types),
95 schema.columns.iter().any(|c| c.auto_increment),
96 )
97 } else {
98 typer.err("Unknown table", table);
99 (None, false)
100 };
101
102 if let Some(values) = &ior.values {
103 for row in &values.1 {
104 for (j, e) in row.iter().enumerate() {
105 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
106 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
107 if typer.matched_type(&t, et).is_none() {
108 typer
109 .err(format!("Got type {}", t.t), e)
110 .frag(format!("Expected {}", et.t), ets);
111 } else if let Type::Args(_, args) = &t.t {
112 for (idx, arg_type, _) in args.iter() {
113 typer.constrain_arg(*idx, arg_type, et);
114 }
115 }
116 } else {
117 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
118 }
119 }
120 if let Some(s) = &s {
121 if s.len() != row.len() {
122 typer
123 .err(
124 format!("Got {} columns", row.len()),
125 &row.opt_span().unwrap(),
126 )
127 .frag(
128 format!("Expected {}", columns.len()),
129 &columns.opt_span().unwrap(),
130 );
131 }
132 }
133 }
134 }
135
136 if let Some(select) = &ior.select {
137 let select = type_select(typer, select, true);
138 if let Some(s) = s {
139 for i in 0..usize::max(s.len(), select.columns.len()) {
140 match (s.get(i), select.columns.get(i)) {
141 (Some((et, ets)), Some(t)) => {
142 if typer.matched_type(&t.type_, et).is_none() {
143 typer
144 .err(format!("Got type {}", t.type_.t), &t.span)
145 .frag(format!("Expected {}", et.t), ets);
146 }
147 }
148 (None, Some(t)) => {
149 typer.err("Column in select not in insert", &t.span);
150 }
151 (Some((_, ets)), None) => {
152 typer.err("Missing column in select", ets);
153 }
154 (None, None) => {
155 panic!("ICE")
156 }
157 }
158 }
159 }
160 }
161
162 let mut guard = typer_stack(
163 typer,
164 |t| core::mem::take(&mut t.reference_types),
165 |t, v| t.reference_types = v,
166 );
167 let typer = &mut guard.typer;
168
169 if let Some(s) = typer.schemas.schemas.get(table.value) {
170 let mut columns = Vec::new();
171 for c in &s.columns {
172 columns.push((c.identifier.clone(), c.type_.clone()));
173 }
174 for v in &typer.reference_types {
175 if v.name == Some(table.clone()) {
176 typer
177 .issues
178 .err("Duplicate definitions", table)
179 .frag("Already defined here", &v.span);
180 }
181 }
182 typer.reference_types.push(ReferenceType {
183 name: Some(table.clone()),
184 span: table.span(),
185 columns,
186 });
187 }
188
189 if let Some(set) = &ior.set {
190 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
191 let mut cnt = 0;
192 let mut t = None;
193 for r in &typer.reference_types {
194 for c in &r.columns {
195 if c.0 == *column {
196 cnt += 1;
197 t = Some(c.clone());
198 }
199 }
200 }
201 if cnt > 1 {
202 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
203 let mut issue = typer.issues.err("Ambiguous reference", column);
204 for r in &typer.reference_types {
205 for c in &r.columns {
206 if c.0 == *column {
207 issue.frag("Defined here", &r.span);
208 }
209 }
210 }
211 } else if let Some(t) = t {
212 let value_type =
213 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
214 if typer.matched_type(&value_type, &t.1).is_none() {
215 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
216 } else if let Type::Args(_, args) = &value_type.t {
217 for (idx, arg_type, _) in args.iter() {
218 typer.constrain_arg(*idx, arg_type, &t.1);
219 }
220 }
221 } else {
222 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
223 typer.err("Unknown identifier", column);
224 }
225 }
226 }
227
228 if let Some(up) = &ior.on_duplicate_key_update {
229 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
230 let mut cnt = 0;
231 let mut t = None;
232 for r in &typer.reference_types {
233 for c in &r.columns {
234 if c.0 == *column {
235 cnt += 1;
236 t = Some(c.clone());
237 }
238 }
239 }
240 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
241 if cnt > 1 {
242 type_expression(typer, value, flags, BaseType::Any);
243 let mut issue = typer.issues.err("Ambiguous reference", column);
244 for r in &typer.reference_types {
245 for c in &r.columns {
246 if c.0 == *column {
247 issue.frag("Defined here", &r.span);
248 }
249 }
250 }
251 } else if let Some(t) = t {
252 let value_type = type_expression(typer, value, flags, t.1.base());
253 if typer.matched_type(&value_type, &t.1).is_none() {
254 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
255 } else if let Type::Args(_, args) = &value_type.t {
256 for (idx, arg_type, _) in args.iter() {
257 typer.constrain_arg(*idx, arg_type, &t.1);
258 }
259 }
260 } else {
261 type_expression(typer, value, flags, BaseType::Any);
262 typer.err("Unknown identifier", column);
263 }
264 }
265 }
266
267 if let Some(on_conflict) = &ior.on_conflict {
268 match &on_conflict.target {
269 sql_parse::OnConflictTarget::Column { name } => {
270 let mut t = None;
271 for r in &typer.reference_types {
272 for c in &r.columns {
273 if c.0 == *name {
274 t = Some(c.clone());
275 }
276 }
277 }
278 if t.is_none() {
279 typer.err("Unknown identifier", name);
280 }
281 }
283 sql_parse::OnConflictTarget::OnConstraint {
284 on_constraint_span, ..
285 } => {
286 issue_todo!(typer.issues, on_constraint_span);
287 }
288 sql_parse::OnConflictTarget::None => (),
289 }
290
291 match &on_conflict.action {
292 sql_parse::OnConflictAction::DoNothing(_) => (),
293 sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
294 for (key, value) in sets {
295 let mut cnt = 0;
296 let mut t = None;
297 for r in &typer.reference_types {
298 for c in &r.columns {
299 if c.0 == *key {
300 cnt += 1;
301 t = Some(c.clone());
302 }
303 }
304 }
305 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
306 if cnt > 1 {
307 type_expression(typer, value, flags, BaseType::Any);
308 let mut issue = typer.issues.err("Ambiguous reference", key);
309 for r in &typer.reference_types {
310 for c in &r.columns {
311 if c.0 == *key {
312 issue.frag("Defined here", &r.span);
313 }
314 }
315 }
316 } else if let Some(t) = t {
317 let value_type = type_expression(typer, value, flags, t.1.base());
318 if typer.matched_type(&value_type, &t.1).is_none() {
319 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
320 } else if let Type::Args(_, args) = &value_type.t {
321 for (idx, arg_type, _) in args.iter() {
322 typer.constrain_arg(*idx, arg_type, &t.1);
323 }
324 }
325 } else {
326 type_expression(typer, value, flags, BaseType::Any);
327 typer.err("Unknown identifier", key);
328 }
329 }
330 if let Some((_, where_)) = where_ {
331 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
332 }
333 }
334 }
335 }
336
337 let returning_select = match &ior.returning {
338 Some((returning_span, returning_exprs)) => {
339 let columns = type_select_exprs(typer, returning_exprs, true)
340 .into_iter()
341 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
342 .collect();
343 Some(SelectType {
344 columns,
345 select_span: returning_span.join_span(returning_exprs),
346 })
347 }
348 None => None,
349 };
350
351 core::mem::drop(guard);
352
353 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
354 if ior
355 .flags
356 .iter()
357 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
358 || ior.on_duplicate_key_update.is_some()
359 {
360 AutoIncrementId::Optional
361 } else {
362 AutoIncrementId::Yes
363 }
364 } else {
365 AutoIncrementId::No
366 };
367
368 (auto_increment_id, returning_select)
369}