1use alloc::{format, vec::Vec};
14use sql_parse::{
15 issue_ice, issue_todo, Expression, Identifier, IdentifierPart, Issues, OptSpanned, Select,
16 SelectExpr, Span, Spanned, Statement, Union,
17};
18
19use crate::{
20 type_::{BaseType, FullType},
21 type_expression::{type_expression, ExpressionFlags},
22 type_reference::type_reference,
23 typer::{typer_stack, ReferenceType, Typer},
24 Type,
25};
26
/// A single column in the result of a select (or union) statement.
#[derive(Debug, Clone)]
pub struct SelectTypeColumn<'a> {
    /// Name of the column: the alias when one was given, otherwise the name of
    /// the referenced column. `None` for unaliased expressions that are not
    /// plain column references.
    pub name: Option<Identifier<'a>>,
    /// Type of the values in the column, including whether it can be null.
    pub type_: FullType<'a>,
    /// Span in the query associated with this column (for example its alias,
    /// its column identifier, or the `*` it was expanded from).
    pub span: Span,
}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39 fn span(&self) -> Span {
40 self.span.span()
41 }
42}
43
/// The typed result of a select or union statement.
#[derive(Debug, Clone)]
pub(crate) struct SelectType<'a> {
    /// Result columns, in output order.
    pub columns: Vec<SelectTypeColumn<'a>>,
    /// Span of the statement the columns were computed from; used as the
    /// fallback span when there are no columns.
    pub select_span: Span,
}
49
50impl<'a> Spanned for SelectType<'a> {
51 fn span(&self) -> Span {
52 self.columns
53 .opt_span()
54 .unwrap_or_else(|| self.select_span.clone())
55 }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59 typer: &mut Typer<'a, 'b>,
60 parts: &[IdentifierPart<'a>],
61 as_: &Option<Identifier<'a>>,
62 mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64 match parts {
65 [sql_parse::IdentifierPart::Name(col)] => {
66 let mut cnt = 0;
67 let mut t = None;
68 for r in &typer.reference_types {
69 for c in &r.columns {
70 if c.0 == *col {
71 cnt += 1;
72 t = Some(c);
73 }
74 }
75 }
76 let name = as_.as_ref().unwrap_or(col);
77 if cnt > 1 {
78 let mut issue = typer.issues.err("Ambigious reference", col);
79 for r in &typer.reference_types {
80 for c in &r.columns {
81 if c.0 == *col {
82 issue.frag("Defined here", &r.span);
83 }
84 }
85 }
86 cb(
87 typer.issues,
88 Some(name.clone()),
89 FullType::invalid(),
90 name.span(),
91 as_.is_some(),
92 );
93 } else if let Some(t) = t {
94 cb(
95 typer.issues,
96 Some(name.clone()),
97 t.1.clone(),
98 name.span(),
99 as_.is_some(),
100 );
101 } else {
102 typer.err("Unknown identifier", col);
103 cb(
104 typer.issues,
105 Some(name.clone()),
106 FullType::invalid(),
107 name.span(),
108 as_.is_some(),
109 );
110 }
111 }
112 [sql_parse::IdentifierPart::Star(v)] => {
113 if let Some(as_) = as_ {
114 typer.err("As not supported for *", as_);
115 }
116 for r in &typer.reference_types {
117 for c in &r.columns {
118 cb(
119 typer.issues,
120 Some(c.0.clone()),
121 c.1.clone(),
122 v.clone(),
123 false,
124 );
125 }
126 }
127 }
128 [sql_parse::IdentifierPart::Name(tbl), sql_parse::IdentifierPart::Name(col)] => {
129 let mut t = None;
130 for r in &typer.reference_types {
131 if r.name == Some(tbl.clone()) {
132 for c in &r.columns {
133 if c.0 == *col {
134 t = Some(c);
135 }
136 }
137 }
138 }
139 let name = as_.as_ref().unwrap_or(col);
140 if let Some(t) = t {
141 cb(
142 typer.issues,
143 Some(name.clone()),
144 t.1.clone(),
145 name.span(),
146 as_.is_some(),
147 );
148 } else {
149 typer.err("Unknown identifier", col);
150 cb(
151 typer.issues,
152 Some(name.clone()),
153 FullType::invalid(),
154 name.span(),
155 as_.is_some(),
156 );
157 }
158 }
159 [sql_parse::IdentifierPart::Name(tbl), sql_parse::IdentifierPart::Star(v)] => {
160 if let Some(as_) = as_ {
161 typer.err("As not supported for *", as_);
162 }
163 let mut t = None;
164 for r in &typer.reference_types {
165 if r.name == Some(tbl.clone()) {
166 t = Some(r);
167 }
168 }
169 if let Some(t) = t {
170 for c in &t.columns {
171 cb(
172 typer.issues,
173 Some(c.0.clone()),
174 c.1.clone(),
175 v.clone(),
176 false,
177 );
178 }
179 } else {
180 typer.err("Unknown table", tbl);
181 }
182 }
183 [sql_parse::IdentifierPart::Star(v), _] => {
184 typer.err("Not supported here", v);
185 }
186 _ => {
187 typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
188 }
189 }
190}
191
192pub(crate) fn type_select<'a>(
193 typer: &mut Typer<'a, '_>,
194 select: &Select<'a>,
195 warn_duplicate: bool,
196) -> SelectType<'a> {
197 let mut guard = typer_stack(
198 typer,
199 |t| t.reference_types.clone(),
200 |t, v| t.reference_types = v,
201 );
202 let typer = &mut guard.typer;
203
204 for flag in &select.flags {
205 match &flag {
206 sql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
207 sql_parse::SelectFlag::Distinct(_) | sql_parse::SelectFlag::DistinctRow(_) => (),
208 sql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
209 sql_parse::SelectFlag::HighPriority(_)
210 | sql_parse::SelectFlag::SqlSmallResult(_)
211 | sql_parse::SelectFlag::SqlBigResult(_)
212 | sql_parse::SelectFlag::SqlBufferResult(_)
213 | sql_parse::SelectFlag::SqlNoCache(_)
214 | sql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
215 }
216 }
217
218 if let Some(references) = &select.table_references {
219 for reference in references {
220 type_reference(typer, reference, false);
221 }
222 }
223
224 if let Some((where_, _)) = &select.where_ {
225 let t = type_expression(
226 typer,
227 where_,
228 ExpressionFlags::default()
229 .with_not_null(true)
230 .with_true(true),
231 BaseType::Bool,
232 );
233 typer.ensure_base(where_, &t, BaseType::Bool);
234 }
235
236 let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
237
238 if let Some((_, group_by)) = &select.group_by {
239 for e in group_by {
240 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
241 }
242 }
243
244 if let Some((_, order_by)) = &select.order_by {
245 for (e, _) in order_by {
246 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
247 }
248 }
249
250 if let Some((having, _)) = &select.having {
251 let t = type_expression(
252 typer,
253 having,
254 ExpressionFlags::default()
255 .with_not_null(true)
256 .with_true(true),
257 BaseType::Bool,
258 );
259 typer.ensure_base(having, &t, BaseType::Bool);
260 }
261
262 if let Some((_, offset, count)) = &select.limit {
263 if let Some(offset) = offset {
264 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
265 if typer
266 .matched_type(&t, &FullType::new(Type::U64, true))
267 .is_none()
268 {
269 typer.err(format!("Expected integer type got {}", t.t), offset);
270 }
271 }
272 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
273 if typer
274 .matched_type(&t, &FullType::new(Type::U64, true))
275 .is_none()
276 {
277 typer.err(format!("Expected integer type got {}", t.t), count);
278 }
279 }
280
281 SelectType {
282 columns: result
283 .into_iter()
284 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
285 .collect(),
286 select_span: select.span(),
287 }
288}
289
290pub(crate) fn type_select_exprs<'a, 'b>(
291 typer: &mut Typer<'a, 'b>,
292 select_exprs: &[SelectExpr<'a>],
293 warn_duplicate: bool,
294) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
295 let mut result = Vec::new();
296 let mut select_reference = ReferenceType {
297 name: None,
298 span: select_exprs.opt_span().expect("select_exprs span"),
299 columns: Vec::new(),
300 };
301
302 for e in select_exprs {
303 let mut add_result = |issues: &mut Issues<'a>,
304 name: Option<Identifier<'a>>,
305 type_: FullType<'a>,
306 span: Span,
307 as_: bool| {
308 if let Some(name) = name.clone() {
309 if as_ {
310 select_reference.columns.push((name.clone(), type_.clone()));
311 }
312 for (on, _, os) in &result {
313 if Some(name.clone()) == *on && warn_duplicate {
314 issues
315 .warn("Also defined here", &span)
316 .frag(format!("Multiple columns with the name '{}'", name), os);
317 }
318 }
319 }
320 result.push((name, type_, span));
321 };
322 if let Expression::Identifier(parts) = &e.expr {
323 resolve_kleene_identifier(typer, parts, &e.as_, add_result);
324 } else {
325 let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
326 if let Some(as_) = &e.as_ {
327 add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
328 } else {
329 if typer.options.warn_unnamed_column_in_select {
330 typer.issues.warn("Unnamed column in select", e);
331 }
332 add_result(typer.issues, None, type_, 0..0, false);
333 };
334 }
335 }
336
337 typer.reference_types.push(select_reference);
338
339 result
340}
341
342pub(crate) fn type_union<'a>(typer: &mut Typer<'a, '_>, union: &Union<'a>) -> SelectType<'a> {
343 let mut t = type_union_select(typer, &union.left, true);
344 let mut left = union.left.span();
345 for w in &union.with {
346 let t2 = type_union_select(typer, &w.union_statement, true);
347
348 for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
349 if let Some(l) = t.columns.get_mut(i) {
350 if let Some(r) = t2.columns.get(i) {
351 if l.name != r.name {
352 if let Some(ln) = &l.name {
353 if let Some(rn) = &r.name {
354 typer
355 .err("Incompatible names in union", &w.union_span)
356 .frag(format!("Column {} is named {}", i, ln), &left)
357 .frag(
358 format!("Column {} is named {}", i, rn),
359 &w.union_statement,
360 );
361 } else {
362 typer
363 .err("Incompatible names in union", &w.union_span)
364 .frag(format!("Column {} is named {}", i, ln), &left)
365 .frag(format!("Column {} has no name", i), &w.union_statement);
366 }
367 } else {
368 typer
369 .err("Incompatible names in union", &w.union_span)
370 .frag(format!("Column {} has no name", i), &left)
371 .frag(
372 format!(
373 "Column {} is named {}",
374 i,
375 r.name.as_ref().expect("name")
376 ),
377 &w.union_statement,
378 );
379 }
380 }
381 if l.type_.t == r.type_.t {
382 l.type_ =
383 FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
384 } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
385 l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
386 } else {
387 typer
388 .err("Incompatible types in union", &w.union_span)
389 .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
390 .frag(
391 format!("Column {} is of type {}", i, r.type_.t),
392 &w.union_statement,
393 );
394 }
395 } else if let Some(n) = &l.name {
396 typer
397 .err("Incompatible types in union", &w.union_span)
398 .frag(format!("Column {} ({}) only on this side", i, n), &left);
399 } else {
400 typer
401 .err("Incompatible types in union", &w.union_span)
402 .frag(format!("Column {} only on this side", i), &left);
403 }
404 } else if let Some(n) = &t2.columns[i].name {
405 typer
406 .err("Incompatible types in union", &w.union_span)
407 .frag(
408 format!("Column {} ({}) only on this side", i, n),
409 &w.union_statement,
410 );
411 } else {
412 typer
413 .err("Incompatible types in union", &w.union_span)
414 .frag(
415 format!("Column {} only on this side", i),
416 &w.union_statement,
417 );
418 }
419 }
420 left = left.join_span(&w.union_statement);
421 }
422
423 typer.reference_types.push(ReferenceType {
424 name: None,
425 span: t.span(),
426 columns: t
427 .columns
428 .iter()
429 .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
430 .collect(),
431 });
432
433 if let Some((_, order_by)) = &union.order_by {
434 for (e, _) in order_by {
435 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
436 }
437 }
438
439 if let Some((_, offset, count)) = &union.limit {
440 if let Some(offset) = offset {
441 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
442 if typer
443 .matched_type(&t, &FullType::new(Type::U64, true))
444 .is_none()
445 {
446 typer.err(format!("Expected integer type got {}", t.t), offset);
447 }
448 }
449 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
450 if typer
451 .matched_type(&t, &FullType::new(Type::U64, true))
452 .is_none()
453 {
454 typer.err(format!("Expected integer type got {}", t.t), count);
455 }
456 }
457
458 typer.reference_types.pop();
459
460 t
461}
462
463pub(crate) fn type_union_select<'a>(
464 typer: &mut Typer<'a, '_>,
465 statement: &Statement<'a>,
466 warn_duplicate: bool,
467) -> SelectType<'a> {
468 match statement {
469 Statement::Select(s) => type_select(typer, s, warn_duplicate),
470 Statement::Union(u) => type_union(typer, u),
471 s => {
472 issue_ice!(typer.issues, s);
473 SelectType {
474 columns: Vec::new(),
475 select_span: s.span(),
476 }
477 }
478 }
479}