sql_type/
type_select.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use sql_parse::{
15    Expression, Identifier, IdentifierPart, Issues, OptSpanned, Select, SelectExpr, Span, Spanned,
16    Statement, Union, issue_ice, issue_todo,
17};
18
19use crate::{
20    Type,
21    type_::{BaseType, FullType},
22    type_expression::{ExpressionFlags, type_expression},
23    type_reference::type_reference,
24    typer::{ReferenceType, Typer, typer_stack},
25};
26
27/// A column in select
28#[derive(Debug, Clone)]
29pub struct SelectTypeColumn<'a> {
30    /// The name of the column if one is specified or can be computed
31    pub name: Option<Identifier<'a>>,
32    /// The type of the data
33    pub type_: FullType<'a>,
34    /// A span of the expression yielding the column
35    pub span: Span,
36}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39    fn span(&self) -> Span {
40        self.span.span()
41    }
42}
43
44#[derive(Debug, Clone)]
45pub(crate) struct SelectType<'a> {
46    pub columns: Vec<SelectTypeColumn<'a>>,
47    pub select_span: Span,
48}
49
50impl<'a> Spanned for SelectType<'a> {
51    fn span(&self) -> Span {
52        self.columns
53            .opt_span()
54            .unwrap_or_else(|| self.select_span.clone())
55    }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59    typer: &mut Typer<'a, 'b>,
60    parts: &[IdentifierPart<'a>],
61    as_: &Option<Identifier<'a>>,
62    mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64    match parts {
65        [sql_parse::IdentifierPart::Name(col)] => {
66            let mut cnt = 0;
67            let mut t = None;
68            for r in &typer.reference_types {
69                for c in &r.columns {
70                    if c.0 == *col {
71                        cnt += 1;
72                        t = Some(c);
73                    }
74                }
75            }
76            let name = as_.as_ref().unwrap_or(col);
77            if cnt > 1 {
78                let mut issue = typer.issues.err("Ambigious reference", col);
79                for r in &typer.reference_types {
80                    for c in &r.columns {
81                        if c.0 == *col {
82                            issue.frag("Defined here", &r.span);
83                        }
84                    }
85                }
86                cb(
87                    typer.issues,
88                    Some(name.clone()),
89                    FullType::invalid(),
90                    name.span(),
91                    as_.is_some(),
92                );
93            } else if let Some(t) = t {
94                cb(
95                    typer.issues,
96                    Some(name.clone()),
97                    t.1.clone(),
98                    name.span(),
99                    as_.is_some(),
100                );
101            } else {
102                typer.err("Unknown identifier", col);
103                cb(
104                    typer.issues,
105                    Some(name.clone()),
106                    FullType::invalid(),
107                    name.span(),
108                    as_.is_some(),
109                );
110            }
111        }
112        [sql_parse::IdentifierPart::Star(v)] => {
113            if let Some(as_) = as_ {
114                typer.err("As not supported for *", as_);
115            }
116            for r in &typer.reference_types {
117                for c in &r.columns {
118                    cb(
119                        typer.issues,
120                        Some(c.0.clone()),
121                        c.1.clone(),
122                        v.clone(),
123                        false,
124                    );
125                }
126            }
127        }
128        [
129            sql_parse::IdentifierPart::Name(tbl),
130            sql_parse::IdentifierPart::Name(col),
131        ] => {
132            let mut t = None;
133            for r in &typer.reference_types {
134                if r.name == Some(tbl.clone()) {
135                    for c in &r.columns {
136                        if c.0 == *col {
137                            t = Some(c);
138                        }
139                    }
140                }
141            }
142            let name = as_.as_ref().unwrap_or(col);
143            if let Some(t) = t {
144                cb(
145                    typer.issues,
146                    Some(name.clone()),
147                    t.1.clone(),
148                    name.span(),
149                    as_.is_some(),
150                );
151            } else {
152                typer.err("Unknown identifier", col);
153                cb(
154                    typer.issues,
155                    Some(name.clone()),
156                    FullType::invalid(),
157                    name.span(),
158                    as_.is_some(),
159                );
160            }
161        }
162        [
163            sql_parse::IdentifierPart::Name(tbl),
164            sql_parse::IdentifierPart::Star(v),
165        ] => {
166            if let Some(as_) = as_ {
167                typer.err("As not supported for *", as_);
168            }
169            let mut t = None;
170            for r in &typer.reference_types {
171                if r.name == Some(tbl.clone()) {
172                    t = Some(r);
173                }
174            }
175            if let Some(t) = t {
176                for c in &t.columns {
177                    cb(
178                        typer.issues,
179                        Some(c.0.clone()),
180                        c.1.clone(),
181                        v.clone(),
182                        false,
183                    );
184                }
185            } else {
186                typer.err("Unknown table", tbl);
187            }
188        }
189        [sql_parse::IdentifierPart::Star(v), _] => {
190            typer.err("Not supported here", v);
191        }
192        _ => {
193            typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
194        }
195    }
196}
197
198pub(crate) fn type_select<'a>(
199    typer: &mut Typer<'a, '_>,
200    select: &Select<'a>,
201    warn_duplicate: bool,
202) -> SelectType<'a> {
203    let mut guard = typer_stack(
204        typer,
205        |t| t.reference_types.clone(),
206        |t, v| t.reference_types = v,
207    );
208    let typer = &mut guard.typer;
209
210    for flag in &select.flags {
211        match &flag {
212            sql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
213            sql_parse::SelectFlag::Distinct(_) | sql_parse::SelectFlag::DistinctRow(_) => (),
214            sql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
215            sql_parse::SelectFlag::HighPriority(_)
216            | sql_parse::SelectFlag::SqlSmallResult(_)
217            | sql_parse::SelectFlag::SqlBigResult(_)
218            | sql_parse::SelectFlag::SqlBufferResult(_)
219            | sql_parse::SelectFlag::SqlNoCache(_)
220            | sql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
221        }
222    }
223
224    if let Some(references) = &select.table_references {
225        for reference in references {
226            type_reference(typer, reference, false);
227        }
228    }
229
230    if let Some((where_, _)) = &select.where_ {
231        let t = type_expression(
232            typer,
233            where_,
234            ExpressionFlags::default()
235                .with_not_null(true)
236                .with_true(true),
237            BaseType::Bool,
238        );
239        typer.ensure_base(where_, &t, BaseType::Bool);
240    }
241
242    let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
243
244    if let Some((_, group_by)) = &select.group_by {
245        for e in group_by {
246            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
247        }
248    }
249
250    if let Some((_, order_by)) = &select.order_by {
251        for (e, _) in order_by {
252            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
253        }
254    }
255
256    if let Some((having, _)) = &select.having {
257        let t = type_expression(
258            typer,
259            having,
260            ExpressionFlags::default()
261                .with_not_null(true)
262                .with_true(true),
263            BaseType::Bool,
264        );
265        typer.ensure_base(having, &t, BaseType::Bool);
266    }
267
268    if let Some((_, offset, count)) = &select.limit {
269        if let Some(offset) = offset {
270            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
271            if typer
272                .matched_type(&t, &FullType::new(Type::U64, true))
273                .is_none()
274            {
275                typer.err(format!("Expected integer type got {}", t.t), offset);
276            }
277        }
278        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
279        if typer
280            .matched_type(&t, &FullType::new(Type::U64, true))
281            .is_none()
282        {
283            typer.err(format!("Expected integer type got {}", t.t), count);
284        }
285    }
286
287    SelectType {
288        columns: result
289            .into_iter()
290            .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
291            .collect(),
292        select_span: select.span(),
293    }
294}
295
296pub(crate) fn type_select_exprs<'a, 'b>(
297    typer: &mut Typer<'a, 'b>,
298    select_exprs: &[SelectExpr<'a>],
299    warn_duplicate: bool,
300) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
301    let mut result = Vec::new();
302    let mut select_reference = ReferenceType {
303        name: None,
304        span: select_exprs.opt_span().expect("select_exprs span"),
305        columns: Vec::new(),
306    };
307
308    for e in select_exprs {
309        let mut add_result = |issues: &mut Issues<'a>,
310                              name: Option<Identifier<'a>>,
311                              type_: FullType<'a>,
312                              span: Span,
313                              as_: bool| {
314            if let Some(name) = name.clone() {
315                if as_ {
316                    select_reference.columns.push((name.clone(), type_.clone()));
317                }
318                for (on, _, os) in &result {
319                    if Some(name.clone()) == *on && warn_duplicate {
320                        issues
321                            .warn("Also defined here", &span)
322                            .frag(format!("Multiple columns with the name '{name}'"), os);
323                    }
324                }
325            }
326            result.push((name, type_, span));
327        };
328        if let Expression::Identifier(parts) = &e.expr {
329            resolve_kleene_identifier(typer, parts, &e.as_, add_result);
330        } else {
331            let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
332            if let Some(as_) = &e.as_ {
333                add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
334            } else {
335                if typer.options.warn_unnamed_column_in_select {
336                    typer.issues.warn("Unnamed column in select", e);
337                }
338                add_result(typer.issues, None, type_, 0..0, false);
339            };
340        }
341    }
342
343    typer.reference_types.push(select_reference);
344
345    result
346}
347
348pub(crate) fn type_union<'a>(typer: &mut Typer<'a, '_>, union: &Union<'a>) -> SelectType<'a> {
349    let mut t = type_union_select(typer, &union.left, true);
350    let mut left = union.left.span();
351    for w in &union.with {
352        let t2 = type_union_select(typer, &w.union_statement, true);
353
354        for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
355            if let Some(l) = t.columns.get_mut(i) {
356                if let Some(r) = t2.columns.get(i) {
357                    if l.name != r.name {
358                        if let Some(ln) = &l.name {
359                            if let Some(rn) = &r.name {
360                                typer
361                                    .err("Incompatible names in union", &w.union_span)
362                                    .frag(format!("Column {i} is named {ln}"), &left)
363                                    .frag(format!("Column {i} is named {rn}"), &w.union_statement);
364                            } else {
365                                typer
366                                    .err("Incompatible names in union", &w.union_span)
367                                    .frag(format!("Column {i} is named {ln}"), &left)
368                                    .frag(format!("Column {i} has no name"), &w.union_statement);
369                            }
370                        } else {
371                            typer
372                                .err("Incompatible names in union", &w.union_span)
373                                .frag(format!("Column {i} has no name"), &left)
374                                .frag(
375                                    format!(
376                                        "Column {} is named {}",
377                                        i,
378                                        r.name.as_ref().expect("name")
379                                    ),
380                                    &w.union_statement,
381                                );
382                        }
383                    }
384                    if l.type_.t == r.type_.t {
385                        l.type_ =
386                            FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
387                    } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
388                        l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
389                    } else {
390                        typer
391                            .err("Incompatible types in union", &w.union_span)
392                            .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
393                            .frag(
394                                format!("Column {} is of type {}", i, r.type_.t),
395                                &w.union_statement,
396                            );
397                    }
398                } else if let Some(n) = &l.name {
399                    typer
400                        .err("Incompatible types in union", &w.union_span)
401                        .frag(format!("Column {i} ({n}) only on this side"), &left);
402                } else {
403                    typer
404                        .err("Incompatible types in union", &w.union_span)
405                        .frag(format!("Column {i} only on this side"), &left);
406                }
407            } else if let Some(n) = &t2.columns[i].name {
408                typer
409                    .err("Incompatible types in union", &w.union_span)
410                    .frag(
411                        format!("Column {i} ({n}) only on this side"),
412                        &w.union_statement,
413                    );
414            } else {
415                typer
416                    .err("Incompatible types in union", &w.union_span)
417                    .frag(format!("Column {i} only on this side"), &w.union_statement);
418            }
419        }
420        left = left.join_span(&w.union_statement);
421    }
422
423    typer.reference_types.push(ReferenceType {
424        name: None,
425        span: t.span(),
426        columns: t
427            .columns
428            .iter()
429            .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
430            .collect(),
431    });
432
433    if let Some((_, order_by)) = &union.order_by {
434        for (e, _) in order_by {
435            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
436        }
437    }
438
439    if let Some((_, offset, count)) = &union.limit {
440        if let Some(offset) = offset {
441            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
442            if typer
443                .matched_type(&t, &FullType::new(Type::U64, true))
444                .is_none()
445            {
446                typer.err(format!("Expected integer type got {}", t.t), offset);
447            }
448        }
449        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
450        if typer
451            .matched_type(&t, &FullType::new(Type::U64, true))
452            .is_none()
453        {
454            typer.err(format!("Expected integer type got {}", t.t), count);
455        }
456    }
457
458    typer.reference_types.pop();
459
460    t
461}
462
463pub(crate) fn type_union_select<'a>(
464    typer: &mut Typer<'a, '_>,
465    statement: &Statement<'a>,
466    warn_duplicate: bool,
467) -> SelectType<'a> {
468    match statement {
469        Statement::Select(s) => type_select(typer, s, warn_duplicate),
470        Statement::Union(u) => type_union(typer, u),
471        s => {
472            issue_ice!(typer.issues, s);
473            SelectType {
474                columns: Vec::new(),
475                select_span: s.span(),
476            }
477        }
478    }
479}