sql_type/
type_select.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

13use alloc::{format, vec::Vec};
14use sql_parse::{
15    issue_ice, issue_todo, Expression, Identifier, IdentifierPart, Issues, OptSpanned, Select,
16    SelectExpr, Span, Spanned, Statement, Union,
17};
18
19use crate::{
20    type_::{BaseType, FullType},
21    type_expression::{type_expression, ExpressionFlags},
22    type_reference::type_reference,
23    typer::{typer_stack, ReferenceType, Typer},
24    Type,
25};
26
/// A single output column of a typed `SELECT` (or `UNION`) statement.
#[derive(Debug, Clone)]
pub struct SelectTypeColumn<'a> {
    /// The name of the column if one is specified or can be computed:
    /// an explicit `AS` alias, or the name of the referenced column.
    pub name: Option<Identifier<'a>>,
    /// The type of the data, including its nullability
    pub type_: FullType<'a>,
    /// A span of the expression yielding the column (for example the alias,
    /// the referenced identifier, or the `*` the column was expanded from)
    pub span: Span,
}
37
38impl<'a> Spanned for SelectTypeColumn<'a> {
39    fn span(&self) -> Span {
40        self.span.span()
41    }
42}
43
/// The typed result of a `SELECT` or `UNION`: its output columns in order.
#[derive(Debug, Clone)]
pub(crate) struct SelectType<'a> {
    /// The output columns, in select-list order
    pub columns: Vec<SelectTypeColumn<'a>>,
    /// Span of the select statement the columns were typed from; used as the
    /// span of the whole result when there are no columns to derive one from
    pub select_span: Span,
}
49
50impl<'a> Spanned for SelectType<'a> {
51    fn span(&self) -> Span {
52        self.columns
53            .opt_span()
54            .unwrap_or_else(|| self.select_span.clone())
55    }
56}
57
58pub(crate) fn resolve_kleene_identifier<'a, 'b>(
59    typer: &mut Typer<'a, 'b>,
60    parts: &[IdentifierPart<'a>],
61    as_: &Option<Identifier<'a>>,
62    mut cb: impl FnMut(&mut Issues<'a>, Option<Identifier<'a>>, FullType<'a>, Span, bool),
63) {
64    match parts {
65        [sql_parse::IdentifierPart::Name(col)] => {
66            let mut cnt = 0;
67            let mut t = None;
68            for r in &typer.reference_types {
69                for c in &r.columns {
70                    if c.0 == *col {
71                        cnt += 1;
72                        t = Some(c);
73                    }
74                }
75            }
76            let name = as_.as_ref().unwrap_or(col);
77            if cnt > 1 {
78                let mut issue = typer.issues.err("Ambigious reference", col);
79                for r in &typer.reference_types {
80                    for c in &r.columns {
81                        if c.0 == *col {
82                            issue.frag("Defined here", &r.span);
83                        }
84                    }
85                }
86                cb(
87                    typer.issues,
88                    Some(name.clone()),
89                    FullType::invalid(),
90                    name.span(),
91                    as_.is_some(),
92                );
93            } else if let Some(t) = t {
94                cb(
95                    typer.issues,
96                    Some(name.clone()),
97                    t.1.clone(),
98                    name.span(),
99                    as_.is_some(),
100                );
101            } else {
102                typer.err("Unknown identifier", col);
103                cb(
104                    typer.issues,
105                    Some(name.clone()),
106                    FullType::invalid(),
107                    name.span(),
108                    as_.is_some(),
109                );
110            }
111        }
112        [sql_parse::IdentifierPart::Star(v)] => {
113            if let Some(as_) = as_ {
114                typer.err("As not supported for *", as_);
115            }
116            for r in &typer.reference_types {
117                for c in &r.columns {
118                    cb(
119                        typer.issues,
120                        Some(c.0.clone()),
121                        c.1.clone(),
122                        v.clone(),
123                        false,
124                    );
125                }
126            }
127        }
128        [sql_parse::IdentifierPart::Name(tbl), sql_parse::IdentifierPart::Name(col)] => {
129            let mut t = None;
130            for r in &typer.reference_types {
131                if r.name == Some(tbl.clone()) {
132                    for c in &r.columns {
133                        if c.0 == *col {
134                            t = Some(c);
135                        }
136                    }
137                }
138            }
139            let name = as_.as_ref().unwrap_or(col);
140            if let Some(t) = t {
141                cb(
142                    typer.issues,
143                    Some(name.clone()),
144                    t.1.clone(),
145                    name.span(),
146                    as_.is_some(),
147                );
148            } else {
149                typer.err("Unknown identifier", col);
150                cb(
151                    typer.issues,
152                    Some(name.clone()),
153                    FullType::invalid(),
154                    name.span(),
155                    as_.is_some(),
156                );
157            }
158        }
159        [sql_parse::IdentifierPart::Name(tbl), sql_parse::IdentifierPart::Star(v)] => {
160            if let Some(as_) = as_ {
161                typer.err("As not supported for *", as_);
162            }
163            let mut t = None;
164            for r in &typer.reference_types {
165                if r.name == Some(tbl.clone()) {
166                    t = Some(r);
167                }
168            }
169            if let Some(t) = t {
170                for c in &t.columns {
171                    cb(
172                        typer.issues,
173                        Some(c.0.clone()),
174                        c.1.clone(),
175                        v.clone(),
176                        false,
177                    );
178                }
179            } else {
180                typer.err("Unknown table", tbl);
181            }
182        }
183        [sql_parse::IdentifierPart::Star(v), _] => {
184            typer.err("Not supported here", v);
185        }
186        _ => {
187            typer.err("Invalid identifier", &parts.opt_span().expect("parts span"));
188        }
189    }
190}
191
192pub(crate) fn type_select<'a>(
193    typer: &mut Typer<'a, '_>,
194    select: &Select<'a>,
195    warn_duplicate: bool,
196) -> SelectType<'a> {
197    let mut guard = typer_stack(
198        typer,
199        |t| t.reference_types.clone(),
200        |t, v| t.reference_types = v,
201    );
202    let typer = &mut guard.typer;
203
204    for flag in &select.flags {
205        match &flag {
206            sql_parse::SelectFlag::All(_) => issue_todo!(typer.issues, flag),
207            sql_parse::SelectFlag::Distinct(_) | sql_parse::SelectFlag::DistinctRow(_) => (),
208            sql_parse::SelectFlag::StraightJoin(_) => issue_todo!(typer.issues, flag),
209            sql_parse::SelectFlag::HighPriority(_)
210            | sql_parse::SelectFlag::SqlSmallResult(_)
211            | sql_parse::SelectFlag::SqlBigResult(_)
212            | sql_parse::SelectFlag::SqlBufferResult(_)
213            | sql_parse::SelectFlag::SqlNoCache(_)
214            | sql_parse::SelectFlag::SqlCalcFoundRows(_) => (),
215        }
216    }
217
218    if let Some(references) = &select.table_references {
219        for reference in references {
220            type_reference(typer, reference, false);
221        }
222    }
223
224    if let Some((where_, _)) = &select.where_ {
225        let t = type_expression(
226            typer,
227            where_,
228            ExpressionFlags::default()
229                .with_not_null(true)
230                .with_true(true),
231            BaseType::Bool,
232        );
233        typer.ensure_base(where_, &t, BaseType::Bool);
234    }
235
236    let result = type_select_exprs(typer, &select.select_exprs, warn_duplicate);
237
238    if let Some((_, group_by)) = &select.group_by {
239        for e in group_by {
240            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
241        }
242    }
243
244    if let Some((_, order_by)) = &select.order_by {
245        for (e, _) in order_by {
246            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
247        }
248    }
249
250    if let Some((having, _)) = &select.having {
251        let t = type_expression(
252            typer,
253            having,
254            ExpressionFlags::default()
255                .with_not_null(true)
256                .with_true(true),
257            BaseType::Bool,
258        );
259        typer.ensure_base(having, &t, BaseType::Bool);
260    }
261
262    if let Some((_, offset, count)) = &select.limit {
263        if let Some(offset) = offset {
264            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
265            if typer
266                .matched_type(&t, &FullType::new(Type::U64, true))
267                .is_none()
268            {
269                typer.err(format!("Expected integer type got {}", t.t), offset);
270            }
271        }
272        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
273        if typer
274            .matched_type(&t, &FullType::new(Type::U64, true))
275            .is_none()
276        {
277            typer.err(format!("Expected integer type got {}", t.t), count);
278        }
279    }
280
281    SelectType {
282        columns: result
283            .into_iter()
284            .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
285            .collect(),
286        select_span: select.span(),
287    }
288}
289
290pub(crate) fn type_select_exprs<'a, 'b>(
291    typer: &mut Typer<'a, 'b>,
292    select_exprs: &[SelectExpr<'a>],
293    warn_duplicate: bool,
294) -> Vec<(Option<Identifier<'a>>, FullType<'a>, Span)> {
295    let mut result = Vec::new();
296    let mut select_reference = ReferenceType {
297        name: None,
298        span: select_exprs.opt_span().expect("select_exprs span"),
299        columns: Vec::new(),
300    };
301
302    for e in select_exprs {
303        let mut add_result = |issues: &mut Issues<'a>,
304                              name: Option<Identifier<'a>>,
305                              type_: FullType<'a>,
306                              span: Span,
307                              as_: bool| {
308            if let Some(name) = name.clone() {
309                if as_ {
310                    select_reference.columns.push((name.clone(), type_.clone()));
311                }
312                for (on, _, os) in &result {
313                    if Some(name.clone()) == *on && warn_duplicate {
314                        issues
315                            .warn("Also defined here", &span)
316                            .frag(format!("Multiple columns with the name '{}'", name), os);
317                    }
318                }
319            }
320            result.push((name, type_, span));
321        };
322        if let Expression::Identifier(parts) = &e.expr {
323            resolve_kleene_identifier(typer, parts, &e.as_, add_result);
324        } else {
325            let type_ = type_expression(typer, &e.expr, ExpressionFlags::default(), BaseType::Any);
326            if let Some(as_) = &e.as_ {
327                add_result(typer.issues, Some(as_.clone()), type_, as_.span(), true);
328            } else {
329                if typer.options.warn_unnamed_column_in_select {
330                    typer.issues.warn("Unnamed column in select", e);
331                }
332                add_result(typer.issues, None, type_, 0..0, false);
333            };
334        }
335    }
336
337    typer.reference_types.push(select_reference);
338
339    result
340}
341
342pub(crate) fn type_union<'a>(typer: &mut Typer<'a, '_>, union: &Union<'a>) -> SelectType<'a> {
343    let mut t = type_union_select(typer, &union.left, true);
344    let mut left = union.left.span();
345    for w in &union.with {
346        let t2 = type_union_select(typer, &w.union_statement, true);
347
348        for i in 0..usize::max(t.columns.len(), t2.columns.len()) {
349            if let Some(l) = t.columns.get_mut(i) {
350                if let Some(r) = t2.columns.get(i) {
351                    if l.name != r.name {
352                        if let Some(ln) = &l.name {
353                            if let Some(rn) = &r.name {
354                                typer
355                                    .err("Incompatible names in union", &w.union_span)
356                                    .frag(format!("Column {} is named {}", i, ln), &left)
357                                    .frag(
358                                        format!("Column {} is named {}", i, rn),
359                                        &w.union_statement,
360                                    );
361                            } else {
362                                typer
363                                    .err("Incompatible names in union", &w.union_span)
364                                    .frag(format!("Column {} is named {}", i, ln), &left)
365                                    .frag(format!("Column {} has no name", i), &w.union_statement);
366                            }
367                        } else {
368                            typer
369                                .err("Incompatible names in union", &w.union_span)
370                                .frag(format!("Column {} has no name", i), &left)
371                                .frag(
372                                    format!(
373                                        "Column {} is named {}",
374                                        i,
375                                        r.name.as_ref().expect("name")
376                                    ),
377                                    &w.union_statement,
378                                );
379                        }
380                    }
381                    if l.type_.t == r.type_.t {
382                        l.type_ =
383                            FullType::new(l.type_.t.clone(), l.type_.not_null && r.type_.not_null);
384                    } else if let Some(t) = typer.matched_type(&l.type_, &r.type_) {
385                        l.type_ = FullType::new(t, l.type_.not_null && r.type_.not_null);
386                    } else {
387                        typer
388                            .err("Incompatible types in union", &w.union_span)
389                            .frag(format!("Column {} is of type {}", i, l.type_.t), &left)
390                            .frag(
391                                format!("Column {} is of type {}", i, r.type_.t),
392                                &w.union_statement,
393                            );
394                    }
395                } else if let Some(n) = &l.name {
396                    typer
397                        .err("Incompatible types in union", &w.union_span)
398                        .frag(format!("Column {} ({}) only on this side", i, n), &left);
399                } else {
400                    typer
401                        .err("Incompatible types in union", &w.union_span)
402                        .frag(format!("Column {} only on this side", i), &left);
403                }
404            } else if let Some(n) = &t2.columns[i].name {
405                typer
406                    .err("Incompatible types in union", &w.union_span)
407                    .frag(
408                        format!("Column {} ({}) only on this side", i, n),
409                        &w.union_statement,
410                    );
411            } else {
412                typer
413                    .err("Incompatible types in union", &w.union_span)
414                    .frag(
415                        format!("Column {} only on this side", i),
416                        &w.union_statement,
417                    );
418            }
419        }
420        left = left.join_span(&w.union_statement);
421    }
422
423    typer.reference_types.push(ReferenceType {
424        name: None,
425        span: t.span(),
426        columns: t
427            .columns
428            .iter()
429            .filter_map(|v| v.name.as_ref().map(|name| (name.clone(), v.type_.clone())))
430            .collect(),
431    });
432
433    if let Some((_, order_by)) = &union.order_by {
434        for (e, _) in order_by {
435            type_expression(typer, e, ExpressionFlags::default(), BaseType::Any);
436        }
437    }
438
439    if let Some((_, offset, count)) = &union.limit {
440        if let Some(offset) = offset {
441            let t = type_expression(typer, offset, ExpressionFlags::default(), BaseType::Integer);
442            if typer
443                .matched_type(&t, &FullType::new(Type::U64, true))
444                .is_none()
445            {
446                typer.err(format!("Expected integer type got {}", t.t), offset);
447            }
448        }
449        let t = type_expression(typer, count, ExpressionFlags::default(), BaseType::Integer);
450        if typer
451            .matched_type(&t, &FullType::new(Type::U64, true))
452            .is_none()
453        {
454            typer.err(format!("Expected integer type got {}", t.t), count);
455        }
456    }
457
458    typer.reference_types.pop();
459
460    t
461}
462
463pub(crate) fn type_union_select<'a>(
464    typer: &mut Typer<'a, '_>,
465    statement: &Statement<'a>,
466    warn_duplicate: bool,
467) -> SelectType<'a> {
468    match statement {
469        Statement::Select(s) => type_select(typer, s, warn_duplicate),
470        Statement::Union(u) => type_union(typer, u),
471        s => {
472            issue_ice!(typer.issues, s);
473            SelectType {
474                columns: Vec::new(),
475                select_span: s.span(),
476            }
477        }
478    }
479}