1use alloc::{format, vec::Vec};
14use sql_parse::{
15 issue_ice, issue_todo, Expression, Identifier, IdentifierPart, Issues, OptSpanned, Select,
16 SelectExpr, Span, Spanned, Statement, Union,
17};
18
19use crate::{
20 type_::{BaseType, FullType},
21 type_expression::{type_expression, ExpressionFlags},
22 type_reference::type_reference,
23 typer::{typer_stack, ReferenceType, Typer},
24 Type,
25};
26
/// Type information for a single result column of a typed `SELECT` (or `UNION`).
#[derive(Debug, Clone)]
pub struct SelectTypeColumn<'a> {
    /// Name of the column, or `None` for an expression that has neither an alias nor an
    /// identifier to take its name from.
    pub name: Option<Identifier<'a>>,
    /// Type of the column, including whether it is known to be not-null.
    pub type_: FullType<'a>,
    /// Source span this column is attributed to (e.g. its alias, its identifier, or the `*`
    /// it was expanded from).
    pub span: Span,
}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39 fn span(&self) -> Span {
40 self.span.span()
41 }
42}
43
/// The result columns produced by typing a `SELECT` or `UNION` statement.
#[derive(Debug, Clone)]
pub(crate) struct SelectType<'a> {
    /// The typed result columns, in output order.
    pub columns: Vec<SelectTypeColumn<'a>>,
    /// Span of the whole statement; used as this type's span when there are no columns.
    pub select_span: Span,
}
49
50impl<'a> Spanned for SelectType<'a> {
51 fn span(&self) -> Span {
52 self.columns
53 .opt_span()
54 .unwrap_or_else(|| self.select_span.clone())
55 }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59 typer: &mut Typer<'a, 'b>,
60 parts: &[IdentifierPart<'a>],
61 as_: &Option<Identifier<'a>>,
62 mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64 match parts {
65 [sql_parse::IdentifierPart::Name(col)] => {
66 let mut cnt = 0;
67 let mut t = None;
68 for r in &typer.reference_types {
69 for c in &r.columns {
70 if c.0 == *col {
71 cnt += 1;
72 t = Some(c);
73 }
74 }
75 }
76 let name = as_.as_ref().unwrap_or(col);
77 if cnt > 1 {
78 let mut issue = typer.issues.err("Ambigious reference", col);
79 for r in &typer.reference_types {
80 for c in &r.columns {
81 if c.0 == *col {
82 issue.frag("Defined here", &r.span);
83 }
84 }
85 }
86 cb(
87 typer.issues,
88 Some(name.clone()),
89 FullType::invalid(),
90 name.span(),
91 as_.is_some(),
92 );
93 } else if let Some(t) = t {
94 cb(
95 typer.issues,
96 Some(name.clone()),
97 t.1.clone(),
98 name.span(),
99 as_.is_some(),
100 );
101 } else {
102 typer.err("Unknown identifier", col);
103 cb(
104 typer.issues,
105 Some(name.clone()),
106 FullType::invalid(),
107 name.span(),
108 as_.is_some(),
109 );
110 }
111 }
112 [sql_parse::IdentifierPart::Star(v)] => {
113 if let Some(as_) = as_ {
114 typer.err("As not supported for *", as_);
115 }
116 for r in &typer.reference_types {
117 for c in &r.columns {
118 cb(
119 typer.issues,
120 Some(c.0.clone()),
121 c.1.clone(),
122 v.clone(),
123 false,
124 );
125 }
126 }
127 }
128 [sql_parse::IdentifierPart::Name(tbl), sql_parse::IdentifierPart::Name(col)] => {
129 let mut t = None;
130 for r in &typer.reference_types {
131 if r.name == Some(tbl.clone()) {
132 for c in &r.columns {
133 if c.0 == *col {
134 t = Some(c);
135 }
136 }
137 }
138 }
139 let name = as_.as_ref().unwrap_or(col);
140 if let Some(t) = t {
141 cb(
142 typer.issues,
143 Some(name.clone()),
144 t.1.clone(),
145 name.span(),
146 as_.is_some(),
147 );
148 } else {
149 typer.err("Unknown identifier", col);
150 cb(
151 typer.issues,
152 Some(name.clone()),
153 FullType::invalid(),
154 name.span(),
155 as_.is_some(),
156 );
157 }
158 }
159 [sql_parse::IdentifierPart::Name(tbl), sql_parse::IdentifierPart::Star(v)] => {
160 if let Some(as_) = as_ {
161 typer.err("As not supported for *", as_);
162 }
163 let mut t = None;
164 for r in &typer.reference_types {
165 if r.name == Some(tbl.clone()) {
166 t = Some(r);
167 }
168 }
169 if let Some(t) = t {
170 for c in &t.columns {
171 cb(
172 typer.issues,
173 Some(c.0.clone()),
174 c.1.clone(),
175 v.clone(),
176 false,
177 );
178 }
179 } else {
180 typer.err("Unknown table", tbl);
181 }
182 }
183 [sql_parse::IdentifierPart::Star(v), _] => {
184 typer.err("Not supported here", v);
185 }
186 _ => {
187 typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
188 }
189 }
190}
191
192pub(crate) fn type_select<'a>(
193 typer: &mut Typer<'a, '_>,
194 select: &Select<'a>,
195 warn_duplicate: bool,
196) -> SelectType<'a> {
197 let mut guard = typer_stack(
198 typer,
199 |t| t.reference_types.clone(),
200 |t, v| t.reference_types = v,
201 );
202 let typer = &mut guard.typer;
203
204 for flag in &select.flags {
205 match &flag {
206 sql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
207 sql_parse::SelectFlag::Distinct(_) | sql_parse::SelectFlag::DistinctRow(_) => (),
208 sql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
209 sql_parse::SelectFlag::HighPriority(_)
210 | sql_parse::SelectFlag::SqlSmallResult(_)
211 | sql_parse::SelectFlag::SqlBigResult(_)
212 | sql_parse::SelectFlag::SqlBufferResult(_)
213 | sql_parse::SelectFlag::SqlNoCache(_)
214 | sql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
215 }
216 }
217
218 if let Some(references) = &select.table_references {
219 for reference in references {
220 type_reference(typer, reference, false);
221 }
222 }
223
224 if let Some((where_, _)) = &select.where_ {
225 let t = type_expression(
226 typer,
227 where_,
228 ExpressionFlags::default()
229 .with_not_null(true)
230 .with_true(true),
231 BaseType::Bool,
232 );
233 typer.ensure_base(where_, &t, BaseType::Bool);
234 }
235
236 let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
237
238 if let Some((_, group_by)) = &select.group_by {
239 for e in group_by {
240 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
241 }
242 }
243
244 if let Some((_, order_by)) = &select.order_by {
245 for (e, _) in order_by {
246 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
247 }
248 }
249
250 if let Some((having, _)) = &select.having {
251 let t = type_expression(
252 typer,
253 having,
254 ExpressionFlags::default()
255 .with_not_null(true)
256 .with_true(true),
257 BaseType::Bool,
258 );
259 typer.ensure_base(having, &t, BaseType::Bool);
260 }
261
262 if let Some((_, offset, count)) = &select.limit {
263 if let Some(offset) = offset {
264 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
265 if typer
266 .matched_type(&t, &FullType::new(Type::U64, true))
267 .is_none()
268 {
269 typer.err(format!("Expected integer type got {}", t.t), offset);
270 }
271 }
272 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
273 if typer
274 .matched_type(&t, &FullType::new(Type::U64, true))
275 .is_none()
276 {
277 typer.err(format!("Expected integer type got {}", t.t), count);
278 }
279 }
280
281 SelectType {
282 columns: result
283 .into_iter()
284 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
285 .collect(),
286 select_span: select.span(),
287 }
288}
289
290pub(crate) fn type_select_exprs<'a, 'b>(
291 typer: &mut Typer<'a, 'b>,
292 select_exprs: &[SelectExpr<'a>],
293 warn_duplicate: bool,
294) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
295 let mut result = Vec::new();
296 let mut select_reference = ReferenceType {
297 name: None,
298 span: select_exprs.opt_span().expect("select_exprs span"),
299 columns: Vec::new(),
300 };
301
302 for e in select_exprs {
303 let mut add_result = |issues: &mut Issues<'a>,
304 name: Option<Identifier<'a>>,
305 type_: FullType<'a>,
306 span: Span,
307 as_: bool| {
308 if let Some(name) = name.clone() {
309 if as_ {
310 select_reference.columns.push((name.clone(), type_.clone()));
311 }
312 for (on, _, os) in &result {
313 if Some(name.clone()) == *on && warn_duplicate {
314 issues
315 .warn("Also defined here", &span)
316 .frag(format!("Multiple columns with the name '{name}'"), os);
317 }
318 }
319 }
320 result.push((name, type_, span));
321 };
322 if let Expression::Identifier(parts) = &e.expr {
323 resolve_kleene_identifier(typer, parts, &e.as_, add_result);
324 } else {
325 let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
326 if let Some(as_) = &e.as_ {
327 add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
328 } else {
329 if typer.options.warn_unnamed_column_in_select {
330 typer.issues.warn("Unnamed column in select", e);
331 }
332 add_result(typer.issues, None, type_, 0..0, false);
333 };
334 }
335 }
336
337 typer.reference_types.push(select_reference);
338
339 result
340}
341
342pub(crate) fn type_union<'a>(typer: &mut Typer<'a, '_>, union: &Union<'a>) -> SelectType<'a> {
343 let mut t = type_union_select(typer, &union.left, true);
344 let mut left = union.left.span();
345 for w in &union.with {
346 let t2 = type_union_select(typer, &w.union_statement, true);
347
348 for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
349 if let Some(l) = t.columns.get_mut(i) {
350 if let Some(r) = t2.columns.get(i) {
351 if l.name != r.name {
352 if let Some(ln) = &l.name {
353 if let Some(rn) = &r.name {
354 typer
355 .err("Incompatible names in union", &w.union_span)
356 .frag(format!("Column {i} is named {ln}"), &left)
357 .frag(format!("Column {i} is named {rn}"), &w.union_statement);
358 } else {
359 typer
360 .err("Incompatible names in union", &w.union_span)
361 .frag(format!("Column {i} is named {ln}"), &left)
362 .frag(format!("Column {i} has no name"), &w.union_statement);
363 }
364 } else {
365 typer
366 .err("Incompatible names in union", &w.union_span)
367 .frag(format!("Column {i} has no name"), &left)
368 .frag(
369 format!(
370 "Column {} is named {}",
371 i,
372 r.name.as_ref().expect("name")
373 ),
374 &w.union_statement,
375 );
376 }
377 }
378 if l.type_.t == r.type_.t {
379 l.type_ =
380 FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
381 } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
382 l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
383 } else {
384 typer
385 .err("Incompatible types in union", &w.union_span)
386 .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
387 .frag(
388 format!("Column {} is of type {}", i, r.type_.t),
389 &w.union_statement,
390 );
391 }
392 } else if let Some(n) = &l.name {
393 typer
394 .err("Incompatible types in union", &w.union_span)
395 .frag(format!("Column {i} ({n}) only on this side"), &left);
396 } else {
397 typer
398 .err("Incompatible types in union", &w.union_span)
399 .frag(format!("Column {i} only on this side"), &left);
400 }
401 } else if let Some(n) = &t2.columns[i].name {
402 typer
403 .err("Incompatible types in union", &w.union_span)
404 .frag(
405 format!("Column {i} ({n}) only on this side"),
406 &w.union_statement,
407 );
408 } else {
409 typer
410 .err("Incompatible types in union", &w.union_span)
411 .frag(format!("Column {i} only on this side"), &w.union_statement);
412 }
413 }
414 left = left.join_span(&w.union_statement);
415 }
416
417 typer.reference_types.push(ReferenceType {
418 name: None,
419 span: t.span(),
420 columns: t
421 .columns
422 .iter()
423 .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
424 .collect(),
425 });
426
427 if let Some((_, order_by)) = &union.order_by {
428 for (e, _) in order_by {
429 type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
430 }
431 }
432
433 if let Some((_, offset, count)) = &union.limit {
434 if let Some(offset) = offset {
435 let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
436 if typer
437 .matched_type(&t, &FullType::new(Type::U64, true))
438 .is_none()
439 {
440 typer.err(format!("Expected integer type got {}", t.t), offset);
441 }
442 }
443 let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
444 if typer
445 .matched_type(&t, &FullType::new(Type::U64, true))
446 .is_none()
447 {
448 typer.err(format!("Expected integer type got {}", t.t), count);
449 }
450 }
451
452 typer.reference_types.pop();
453
454 t
455}
456
457pub(crate) fn type_union_select<'a>(
458 typer: &mut Typer<'a, '_>,
459 statement: &Statement<'a>,
460 warn_duplicate: bool,
461) -> SelectType<'a> {
462 match statement {
463 Statement::Select(s) => type_select(typer, s, warn_duplicate),
464 Statement::Union(u) => type_union(typer, u),
465 s => {
466 issue_ice!(typer.issues, s);
467 SelectType {
468 columns: Vec::new(),
469 select_span: s.span(),
470 }
471 }
472 }
473}