1use alloc::{format, vec::Vec};
14use sql_parse::{
15 Expression, Identifier, IdentifierPart, Issues, OptSpanned, Select, SelectExpr, Span, Spanned,
16 Statement, Union, issue_ice, issue_todo,
17};
18
19use crate::{
20 Type,
21 type_::{BaseType, FullType},
22 type_expression::{ExpressionFlags, type_expression},
23 type_reference::type_reference,
24 typer::{ReferenceType, Typer, typer_stack},
25};
26
27#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30 pub name: Option<Identifier<'a>>,
32 pub type_: FullType<'a>,
34 pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39 fn span(&self) -> Span {
40 self.span.span()
41 }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46 pub columns: Vec<SelectTypeColumn<'a>>,
47 pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51 fn span(&self) -> Span {
52 self.columns
53 .opt_span()
54 .unwrap_or_else(|| self.select_span.clone())
55 }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59 typer: &mut Typer<'a, 'b>,
60 parts: &[IdentifierPart<'a>],
61 as_: &Option<Identifier<'a>>,
62 mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64 match parts {
65 [sql_parse::IdentifierPart::Name(col)] => {
66 let mut cnt = 0;
67 let mut t = None;
68 for r in &typer.reference_types {
69 for c in &r.columns {
70 if c.0 == *col {
71 cnt += 1;
72 t = Some(c);
73 }
74 }
75 }
76 let name = as_.as_ref().unwrap_or(col);
77 if cnt > 1 {
78 let mut issue = typer.issues.err("Ambigious reference", col);
79 for r in &typer.reference_types {
80 for c in &r.columns {
81 if c.0 == *col {
82 issue.frag("Defined here", &r.span);
83 }
84 }
85 }
86 cb(
87 typer.issues,
88 Some(name.clone()),
89 FullType::invalid(),
90 name.span(),
91 as_.is_some(),
92 );
93 } else if let Some(t) = t {
94 cb(
95 typer.issues,
96 Some(name.clone()),
97 t.1.clone(),
98 name.span(),
99 as_.is_some(),
100 );
101 } else {
102 typer.err("Unknown identifier", col);
103 cb(
104 typer.issues,
105 Some(name.clone()),
106 FullType::invalid(),
107 name.span(),
108 as_.is_some(),
109 );
110 }
111 }
112 [sql_parse::IdentifierPart::Star(v)] => {
113 if let Some(as_) = as_ {
114 typer.err("As not supported for *", as_);
115 }
116 for r in &typer.reference_types {
117 for c in &r.columns {
118 cb(
119 typer.issues,
120 Some(c.0.clone()),
121 c.1.clone(),
122 v.clone(),
123 false,
124 );
125 }
126 }
127 }
128 [
129 sql_parse::IdentifierPart::Name(tbl),
130 sql_parse::IdentifierPart::Name(col),
131 ] => {
132 let mut t = None;
133 for r in &typer.reference_types {
134 if r.name == Some(tbl.clone()) {
135 for c in &r.columns {
136 if c.0 == *col {
137 t = Some(c);
138 }
139 }
140 }
141 }
142 let name = as_.as_ref().unwrap_or(col);
143 if let Some(t) = t {
144 cb(
145 typer.issues,
146 Some(name.clone()),
147 t.1.clone(),
148 name.span(),
149 as_.is_some(),
150 );
151 } else {
152 typer.err("Unknown identifier", col);
153 cb(
154 typer.issues,
155 Some(name.clone()),
156 FullType::invalid(),
157 name.span(),
158 as_.is_some(),
159 );
160 }
161 }
162 [
163 sql_parse::IdentifierPart::Name(tbl),
164 sql_parse::IdentifierPart::Star(v),
165 ] => {
166 if let Some(as_) = as_ {
167 typer.err("As not supported for *", as_);
168 }
169 let mut t = None;
170 for r in &typer.reference_types {
171 if r.name == Some(tbl.clone()) {
172 t = Some(r);
173 }
174 }
175 if let Some(t) = t {
176 for c in &t.columns {
177 cb(
178 typer.issues,
179 Some(c.0.clone()),
180 c.1.clone(),
181 v.clone(),
182 false,
183 );
184 }
185 } else {
186 typer.err("Unknown table", tbl);
187 }
188 }
189 [sql_parse::IdentifierPart::Star(v), _] => {
190 typer.err("Not supported here", v);
191 }
192 _ => {
193 typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
194 }
195 }
196}
197
198pub(crate) fn type_select<'a>(
199 typer: &mut Typer<'a, '_>,
200 select: &Select<'a>,
201 warn_duplicate: bool,
202) -> SelectType<'a> {
203 let mut guard = typer_stack(
204 typer,
205 |t| t.reference_types.clone(),
206 |t, v| t.reference_types = v,
207 );
208 let typer = &mut guard.typer;
209
210 for flag in &select.flags {
211 match &flag {
212 sql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
213 sql_parse::SelectFlag::Distinct(_) | sql_parse::SelectFlag::DistinctRow(_) => (),
214 sql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
215 sql_parse::SelectFlag::HighPriority(_)
216 | sql_parse::SelectFlag::SqlSmallResult(_)
217 | sql_parse::SelectFlag::SqlBigResult(_)
218 | sql_parse::SelectFlag::SqlBufferResult(_)
219 | sql_parse::SelectFlag::SqlNoCache(_)
220 | sql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
221 }
222 }
223
224 if let Some(references) = &select.table_references {
225 for reference in references {
226 type_reference(typer, reference, false);
227 }
228 }
229
230 if let Some((where_, _)) = &select.where_ {
231 let t = type_expression(
232 typer,
233 where_,
234 ExpressionFlags::default()
235 .with_not_null(true)
236 .with_true(true),
237 BaseType::Bool,
238 );
239 typer.ensure_base(where_, &t, BaseType::Bool);
240 }
241
242 let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
243
244 if let Some((_, group_by)) = &select.group_by {
245 for e in group_by {
246 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
247 }
248 }
249
250 if let Some((_, order_by)) = &select.order_by {
251 for (e, _) in order_by {
252 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
253 }
254 }
255
256 if let Some((having, _)) = &select.having {
257 let t = type_expression(
258 typer,
259 having,
260 ExpressionFlags::default()
261 .with_not_null(true)
262 .with_true(true),
263 BaseType::Bool,
264 );
265 typer.ensure_base(having, &t, BaseType::Bool);
266 }
267
268 if let Some((_, offset, count)) = &select.limit {
269 if let Some(offset) = offset {
270 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
271 if typer
272 .matched_type(&t, &FullType::new(Type::U64, true))
273 .is_none()
274 {
275 typer.err(format!("Expected integer type got {}", t.t), offset);
276 }
277 }
278 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
279 if typer
280 .matched_type(&t, &FullType::new(Type::U64, true))
281 .is_none()
282 {
283 typer.err(format!("Expected integer type got {}", t.t), count);
284 }
285 }
286
287 SelectType {
288 columns: result
289 .into_iter()
290 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
291 .collect(),
292 select_span: select.span(),
293 }
294}
295
296pub(crate) fn type_select_exprs<'a, 'b>(
297 typer: &mut Typer<'a, 'b>,
298 select_exprs: &[SelectExpr<'a>],
299 warn_duplicate: bool,
300) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
301 let mut result = Vec::new();
302 let mut select_reference = ReferenceType {
303 name: None,
304 span: select_exprs.opt_span().expect("select_exprs span"),
305 columns: Vec::new(),
306 };
307
308 for e in select_exprs {
309 let mut add_result = |issues: &mut Issues<'a>,
310 name: Option<Identifier<'a>>,
311 type_: FullType<'a>,
312 span: Span,
313 as_: bool| {
314 if let Some(name) = name.clone() {
315 if as_ {
316 select_reference.columns.push((name.clone(), type_.clone()));
317 }
318 for (on, _, os) in &result {
319 if Some(name.clone()) == *on && warn_duplicate {
320 issues
321 .warn("Also defined here", &span)
322 .frag(format!("Multiple columns with the name '{name}'"), os);
323 }
324 }
325 }
326 result.push((name, type_, span));
327 };
328 if let Expression::Identifier(parts) = &e.expr {
329 resolve_kleene_identifier(typer, parts, &e.as_, add_result);
330 } else {
331 let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
332 if let Some(as_) = &e.as_ {
333 add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
334 } else {
335 if typer.options.warn_unnamed_column_in_select {
336 typer.issues.warn("Unnamed column in select", e);
337 }
338 add_result(typer.issues, None, type_, 0..0, false);
339 };
340 }
341 }
342
343 typer.reference_types.push(select_reference);
344
345 result
346}
347
348pub(crate) fn type_union<'a>(typer: &mut Typer<'a, '_>, union: &Union<'a>) -> SelectType<'a> {
349 let mut t = type_union_select(typer, &union.left, true);
350 let mut left = union.left.span();
351 for w in &union.with {
352 let t2 = type_union_select(typer, &w.union_statement, true);
353
354 for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
355 if let Some(l) = t.columns.get_mut(i) {
356 if let Some(r) = t2.columns.get(i) {
357 if l.name != r.name {
358 if let Some(ln) = &l.name {
359 if let Some(rn) = &r.name {
360 typer
361 .err("Incompatible names in union", &w.union_span)
362 .frag(format!("Column {i} is named {ln}"), &left)
363 .frag(format!("Column {i} is named {rn}"), &w.union_statement);
364 } else {
365 typer
366 .err("Incompatible names in union", &w.union_span)
367 .frag(format!("Column {i} is named {ln}"), &left)
368 .frag(format!("Column {i} has no name"), &w.union_statement);
369 }
370 } else {
371 typer
372 .err("Incompatible names in union", &w.union_span)
373 .frag(format!("Column {i} has no name"), &left)
374 .frag(
375 format!(
376 "Column {} is named {}",
377 i,
378 r.name.as_ref().expect("name")
379 ),
380 &w.union_statement,
381 );
382 }
383 }
384 if l.type_.t == r.type_.t {
385 l.type_ =
386 FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
387 } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
388 l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
389 } else {
390 typer
391 .err("Incompatible types in union", &w.union_span)
392 .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
393 .frag(
394 format!("Column {} is of type {}", i, r.type_.t),
395 &w.union_statement,
396 );
397 }
398 } else if let Some(n) = &l.name {
399 typer
400 .err("Incompatible types in union", &w.union_span)
401 .frag(format!("Column {i} ({n}) only on this side"), &left);
402 } else {
403 typer
404 .err("Incompatible types in union", &w.union_span)
405 .frag(format!("Column {i} only on this side"), &left);
406 }
407 } else if let Some(n) = &t2.columns[i].name {
408 typer
409 .err("Incompatible types in union", &w.union_span)
410 .frag(
411 format!("Column {i} ({n}) only on this side"),
412 &w.union_statement,
413 );
414 } else {
415 typer
416 .err("Incompatible types in union", &w.union_span)
417 .frag(format!("Column {i} only on this side"), &w.union_statement);
418 }
419 }
420 left = left.join_span(&w.union_statement);
421 }
422
423 typer.reference_types.push(ReferenceType {
424 name: None,
425 span: t.span(),
426 columns: t
427 .columns
428 .iter()
429 .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
430 .collect(),
431 });
432
433 if let Some((_, order_by)) = &union.order_by {
434 for (e, _) in order_by {
435 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
436 }
437 }
438
439 if let Some((_, offset, count)) = &union.limit {
440 if let Some(offset) = offset {
441 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
442 if typer
443 .matched_type(&t, &FullType::new(Type::U64, true))
444 .is_none()
445 {
446 typer.err(format!("Expected integer type got {}", t.t), offset);
447 }
448 }
449 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
450 if typer
451 .matched_type(&t, &FullType::new(Type::U64, true))
452 .is_none()
453 {
454 typer.err(format!("Expected integer type got {}", t.t), count);
455 }
456 }
457
458 typer.reference_types.pop();
459
460 t
461}
462
463pub(crate) fn type_union_select<'a>(
464 typer: &mut Typer<'a, '_>,
465 statement: &Statement<'a>,
466 warn_duplicate: bool,
467) -> SelectType<'a> {
468 match statement {
469 Statement::Select(s) => type_select(typer, s, warn_duplicate),
470 Statement::Union(u) => type_union(typer, u),
471 s => {
472 issue_ice!(typer.issues, s);
473 SelectType {
474 columns: Vec::new(),
475 select_span: s.span(),
476 }
477 }
478 }
479}