sql_type/
type_update.rs

1// Licensed under the Apache License, Version 2.0 (the "License");
2// you may not use this file except in compliance with the License.
3// You may obtain a copy of the License at
4//
5// http://www.apache.org/licenses/LICENSE-2.0
6//
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use sql_parse::{OptSpanned, Spanned, Update};
14
15use crate::{
16    SelectTypeColumn, Type,
17    type_::BaseType,
18    type_expression::{ExpressionFlags, type_expression},
19    type_reference::type_reference,
20    type_select::{SelectType, type_select_exprs},
21    typer::{Typer, typer_stack},
22};
23
24pub(crate) fn type_update<'a>(
25    typer: &mut Typer<'a, '_>,
26    update: &Update<'a>,
27) -> Option<SelectType<'a>> {
28    let mut guard = typer_stack(
29        typer,
30        |t| core::mem::take(&mut t.reference_types),
31        |t, v| t.reference_types = v,
32    );
33    let typer = &mut guard.typer;
34
35    for f in &update.flags {
36        match f {
37            sql_parse::UpdateFlag::LowPriority(_) | sql_parse::UpdateFlag::Ignore(_) => (),
38        }
39    }
40
41    for reference in &update.tables {
42        type_reference(typer, reference, false);
43    }
44
45    for (key, value) in &update.set {
46        let flags = ExpressionFlags::default();
47        match key.as_slice() {
48            [key] => {
49                let mut cnt = 0;
50                let mut t = None;
51                for r in &typer.reference_types {
52                    for c in &r.columns {
53                        if c.0 == *key {
54                            cnt += 1;
55                            t = Some(c.clone());
56                        }
57                    }
58                }
59                if cnt > 1 {
60                    type_expression(typer, value, flags, BaseType::Any);
61                    let mut issue = typer
62                        .issues
63                        .err("Ambiguous reference", &key.opt_span().unwrap());
64                    for r in &typer.reference_types {
65                        for c in &r.columns {
66                            if c.0 == *key {
67                                issue.frag("Defined here", &r.span);
68                            }
69                        }
70                    }
71                } else if let Some(t) = t {
72                    let value_type = type_expression(typer, value, flags, t.1.base());
73                    if typer.matched_type(&value_type, &t.1).is_none() {
74                        typer.err(
75                            alloc::format!("Got type {} expected {}", value_type, t.1),
76                            value,
77                        );
78                    } else if let Type::Args(_, args) = &value_type.t {
79                        for (idx, arg_type, _) in args.iter() {
80                            typer.constrain_arg(*idx, arg_type, &t.1);
81                        }
82                    }
83                } else {
84                    type_expression(typer, value, flags, BaseType::Any);
85                    typer
86                        .issues
87                        .err("Unknown identifier", &key.opt_span().unwrap());
88                }
89            }
90            [table, column] => {
91                let mut t = None;
92                for r in &typer.reference_types {
93                    if r.name != Some(table.clone()) {
94                        continue;
95                    }
96                    for c in &r.columns {
97                        if c.0 == column.clone() {
98                            t = Some(c.clone());
99                        }
100                    }
101                }
102                if let Some(t) = t {
103                    let value_type = type_expression(typer, value, flags, t.1.base());
104                    if typer.matched_type(&value_type, &t.1).is_none() {
105                        typer.err(
106                            alloc::format!("Got type {} expected {}", value_type, t.1),
107                            value,
108                        );
109                    } else if let Type::Args(_, args) = &value_type.t {
110                        for (idx, arg_type, _) in args.iter() {
111                            typer.constrain_arg(*idx, arg_type, &t.1);
112                        }
113                    }
114                } else {
115                    type_expression(typer, value, flags, BaseType::Any);
116                    typer
117                        .issues
118                        .err("Unknown identifier", &key.opt_span().unwrap());
119                }
120            }
121            _ => {
122                type_expression(typer, value, flags, BaseType::Any);
123                typer
124                    .issues
125                    .err("Unknown identifier", &key.opt_span().unwrap());
126            }
127        }
128    }
129
130    if let Some((where_, _)) = &update.where_ {
131        let t = type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
132        typer.ensure_base(where_, &t, BaseType::Bool);
133    }
134
135    match &update.returning {
136        Some((returning_span, returning_exprs)) => {
137            let columns = type_select_exprs(typer, returning_exprs, true)
138                .into_iter()
139                .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
140                .collect();
141            Some(SelectType {
142                columns,
143                select_span: returning_span.join_span(returning_exprs),
144            })
145        }
146        None => None,
147    }
148}