1use alloc::{format, vec::Vec};
14use sql_parse::{
15 issue_todo, InsertReplace, InsertReplaceFlag, InsertReplaceSetPair, InsertReplaceType, Spanned,
16};
17
18use crate::{
19 type_expression::{type_expression, ExpressionFlags},
20 type_select::{type_select, type_select_exprs, SelectType},
21 typer::{typer_stack, unqualified_name, ReferenceType, Typer},
22 BaseType, SelectTypeColumn, Type,
23};
24
/// Whether executing a typed `INSERT` / `REPLACE` statement yields an
/// auto-increment id.
///
/// This is a small fieldless value, so it is `Copy` and can be used as a
/// hash key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AutoIncrementId {
    /// The statement always inserts a row into a table with an
    /// auto-increment column, so an id is always produced.
    Yes,
    /// No auto-increment id is produced (e.g. the table has no
    /// auto-increment column, or the statement is not an `INSERT`).
    No,
    /// An id may or may not be produced, e.g. `INSERT IGNORE` or
    /// `INSERT ... ON DUPLICATE KEY UPDATE`, where a row might not be
    /// inserted.
    Optional,
}
32
33pub(crate) fn type_insert_replace<'a>(
34 typer: &mut Typer<'a, '_>,
35 ior: &InsertReplace<'a>,
36) -> (AutoIncrementId, Option<SelectType<'a>>) {
37 let table = unqualified_name(typer.issues, &ior.table);
38 let columns = &ior.columns;
39
40 let (s, auto_increment) = if let Some(schema) = typer.schemas.schemas.get(table.value) {
41 if schema.view {
42 typer.err("Inserts into views not yet implemented", table);
43 }
44 let mut col_types = Vec::new();
45
46 for col in columns {
47 if let Some(schema_col) = schema.get_column(col.value) {
48 col_types.push((schema_col.type_.clone(), col.span()));
49 } else {
50 typer.err("No such column in schema", col);
51 }
52 }
53 (
54 Some(col_types),
55 schema.columns.iter().any(|c| c.auto_increment),
56 )
57 } else {
58 typer.err("Unknown table", table);
59 (None, false)
60 };
61
62 if let Some(values) = &ior.values {
63 for row in &values.1 {
64 for (j, e) in row.iter().enumerate() {
65 if let Some((et, ets)) = s.as_ref().and_then(|v| v.get(j)) {
66 let t = type_expression(typer, e, ExpressionFlags::default(), et.base());
67 if typer.matched_type(&t, et).is_none() {
68 typer
69 .err(format!("Got type {}", t.t), e)
70 .frag(format!("Expected {}", et.t), ets);
71 } else if let Type::Args(_, args) = &t.t {
72 for (idx, arg_type, _) in args.iter() {
73 typer.constrain_arg(*idx, arg_type, et);
74 }
75 }
76 } else {
77 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
78 }
79 }
80 }
81 }
82
83 if let Some(select) = &ior.select {
84 let select = type_select(typer, select, true);
85 if let Some(s) = s {
86 for i in 0..usize::max(s.len(), select.columns.len()) {
87 match (s.get(i), select.columns.get(i)) {
88 (Some((et, ets)), Some(t)) => {
89 if typer.matched_type(&t.type_, et).is_none() {
90 typer
91 .err(format!("Got type {}", t.type_.t), &t.span)
92 .frag(format!("Expected {}", et.t), ets);
93 }
94 }
95 (None, Some(t)) => {
96 typer.err("Column in select not in insert", &t.span);
97 }
98 (Some((_, ets)), None) => {
99 typer.err("Missing column in select", ets);
100 }
101 (None, None) => {
102 panic!("ICE")
103 }
104 }
105 }
106 }
107 }
108
109 let mut guard = typer_stack(
110 typer,
111 |t| core::mem::take(&mut t.reference_types),
112 |t, v| t.reference_types = v,
113 );
114 let typer = &mut guard.typer;
115
116 if let Some(s) = typer.schemas.schemas.get(table.value) {
117 let mut columns = Vec::new();
118 for c in &s.columns {
119 columns.push((c.identifier.clone(), c.type_.clone()));
120 }
121 for v in &typer.reference_types {
122 if v.name == Some(table.clone()) {
123 typer
124 .issues
125 .err("Duplicate definitions", table)
126 .frag("Already defined here", &v.span);
127 }
128 }
129 typer.reference_types.push(ReferenceType {
130 name: Some(table.clone()),
131 span: table.span(),
132 columns,
133 });
134 }
135
136 if let Some(set) = &ior.set {
137 for InsertReplaceSetPair { column, value, .. } in &set.pairs {
138 let mut cnt = 0;
139 let mut t = None;
140 for r in &typer.reference_types {
141 for c in &r.columns {
142 if c.0 == *column {
143 cnt += 1;
144 t = Some(c.clone());
145 }
146 }
147 }
148 if cnt > 1 {
149 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
150 let mut issue = typer.issues.err("Ambiguous reference", column);
151 for r in &typer.reference_types {
152 for c in &r.columns {
153 if c.0 == *column {
154 issue.frag("Defined here", &r.span);
155 }
156 }
157 }
158 } else if let Some(t) = t {
159 let value_type =
160 type_expression(typer, value, ExpressionFlags::default(), t.1.base());
161 if typer.matched_type(&value_type, &t.1).is_none() {
162 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
163 } else if let Type::Args(_, args) = &value_type.t {
164 for (idx, arg_type, _) in args.iter() {
165 typer.constrain_arg(*idx, arg_type, &t.1);
166 }
167 }
168 } else {
169 type_expression(typer, value, ExpressionFlags::default(), BaseType::Any);
170 typer.err("Unknown identifier", column);
171 }
172 }
173 }
174
175 if let Some(up) = &ior.on_duplicate_key_update {
176 for InsertReplaceSetPair { value, column, .. } in &up.pairs {
177 let mut cnt = 0;
178 let mut t = None;
179 for r in &typer.reference_types {
180 for c in &r.columns {
181 if c.0 == *column {
182 cnt += 1;
183 t = Some(c.clone());
184 }
185 }
186 }
187 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
188 if cnt > 1 {
189 type_expression(typer, value, flags, BaseType::Any);
190 let mut issue = typer.issues.err("Ambiguous reference", column);
191 for r in &typer.reference_types {
192 for c in &r.columns {
193 if c.0 == *column {
194 issue.frag("Defined here", &r.span);
195 }
196 }
197 }
198 } else if let Some(t) = t {
199 let value_type = type_expression(typer, value, flags, t.1.base());
200 if typer.matched_type(&value_type, &t.1).is_none() {
201 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
202 } else if let Type::Args(_, args) = &value_type.t {
203 for (idx, arg_type, _) in args.iter() {
204 typer.constrain_arg(*idx, arg_type, &t.1);
205 }
206 }
207 } else {
208 type_expression(typer, value, flags, BaseType::Any);
209 typer.err("Unknown identifier", column);
210 }
211 }
212 }
213
214 if let Some(on_conflict) = &ior.on_conflict {
215 match &on_conflict.target {
216 sql_parse::OnConflictTarget::Column { name } => {
217 let mut t = None;
218 for r in &typer.reference_types {
219 for c in &r.columns {
220 if c.0 == *name {
221 t = Some(c.clone());
222 }
223 }
224 }
225 if t.is_none() {
226 typer.err("Unknown identifier", name);
227 }
228 }
230 sql_parse::OnConflictTarget::OnConstraint {
231 on_constraint_span, ..
232 } => {
233 issue_todo!(typer.issues, on_constraint_span);
234 }
235 sql_parse::OnConflictTarget::None => (),
236 }
237
238 match &on_conflict.action {
239 sql_parse::OnConflictAction::DoNothing(_) => (),
240 sql_parse::OnConflictAction::DoUpdateSet { sets, where_, .. } => {
241 for (key, value) in sets {
242 let mut cnt = 0;
243 let mut t = None;
244 for r in &typer.reference_types {
245 for c in &r.columns {
246 if c.0 == *key {
247 cnt += 1;
248 t = Some(c.clone());
249 }
250 }
251 }
252 let flags = ExpressionFlags::default().with_in_on_duplicate_key_update(true);
253 if cnt > 1 {
254 type_expression(typer, value, flags, BaseType::Any);
255 let mut issue = typer.issues.err("Ambiguous reference", key);
256 for r in &typer.reference_types {
257 for c in &r.columns {
258 if c.0 == *key {
259 issue.frag("Defined here", &r.span);
260 }
261 }
262 }
263 } else if let Some(t) = t {
264 let value_type = type_expression(typer, value, flags, t.1.base());
265 if typer.matched_type(&value_type, &t.1).is_none() {
266 typer.err(format!("Got type {} expected {}", value_type, t.1), value);
267 } else if let Type::Args(_, args) = &value_type.t {
268 for (idx, arg_type, _) in args.iter() {
269 typer.constrain_arg(*idx, arg_type, &t.1);
270 }
271 }
272 } else {
273 type_expression(typer, value, flags, BaseType::Any);
274 typer.err("Unknown identifier", key);
275 }
276 }
277 if let Some((_, where_)) = where_ {
278 type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
279 }
280 }
281 }
282 }
283
284 let returning_select = match &ior.returning {
285 Some((returning_span, returning_exprs)) => {
286 let columns = type_select_exprs(typer, returning_exprs, true)
287 .into_iter()
288 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
289 .collect();
290 Some(SelectType {
291 columns,
292 select_span: returning_span.join_span(returning_exprs),
293 })
294 }
295 None => None,
296 };
297
298 core::mem::drop(guard);
299
300 let auto_increment_id = if auto_increment && matches!(ior.type_, InsertReplaceType::Insert(_)) {
301 if ior
302 .flags
303 .iter()
304 .any(|f| matches!(f, InsertReplaceFlag::Ignore(_)))
305 || ior.on_duplicate_key_update.is_some()
306 {
307 AutoIncrementId::Optional
308 } else {
309 AutoIncrementId::Yes
310 }
311 } else {
312 AutoIncrementId::No
313 };
314
315 (auto_increment_id, returning_select)
316}