1use sql_parse::{OptSpanned, Spanned, Update};
14
15use crate::{
16 SelectTypeColumn, Type,
17 type_::BaseType,
18 type_expression::{ExpressionFlags, type_expression},
19 type_reference::type_reference,
20 type_select::{SelectType, type_select_exprs},
21 typer::{Typer, typer_stack},
22};
23
24pub(crate) fn type_update<'a>(
25 typer: &mut Typer<'a, '_>,
26 update: &Update<'a>,
27) -> Option<SelectType<'a>> {
28 let mut guard = typer_stack(
29 typer,
30 |t| core::mem::take(&mut t.reference_types),
31 |t, v| t.reference_types = v,
32 );
33 let typer = &mut guard.typer;
34
35 for f in &update.flags {
36 match f {
37 sql_parse::UpdateFlag::LowPriority(_) | sql_parse::UpdateFlag::Ignore(_) => (),
38 }
39 }
40
41 for reference in &update.tables {
42 type_reference(typer, reference, false);
43 }
44
45 for (key, value) in &update.set {
46 let flags = ExpressionFlags::default();
47 match key.as_slice() {
48 [key] => {
49 let mut cnt = 0;
50 let mut t = None;
51 for r in &typer.reference_types {
52 for c in &r.columns {
53 if c.0 == *key {
54 cnt += 1;
55 t = Some(c.clone());
56 }
57 }
58 }
59 if cnt > 1 {
60 type_expression(typer, value, flags, BaseType::Any);
61 let mut issue = typer
62 .issues
63 .err("Ambiguous reference", &key.opt_span().unwrap());
64 for r in &typer.reference_types {
65 for c in &r.columns {
66 if c.0 == *key {
67 issue.frag("Defined here", &r.span);
68 }
69 }
70 }
71 } else if let Some(t) = t {
72 let value_type = type_expression(typer, value, flags, t.1.base());
73 if typer.matched_type(&value_type, &t.1).is_none() {
74 typer.err(
75 alloc::format!("Got type {} expected {}", value_type, t.1),
76 value,
77 );
78 } else if let Type::Args(_, args) = &value_type.t {
79 for (idx, arg_type, _) in args.iter() {
80 typer.constrain_arg(*idx, arg_type, &t.1);
81 }
82 }
83 } else {
84 type_expression(typer, value, flags, BaseType::Any);
85 typer
86 .issues
87 .err("Unknown identifier", &key.opt_span().unwrap());
88 }
89 }
90 [table, column] => {
91 let mut t = None;
92 for r in &typer.reference_types {
93 if r.name != Some(table.clone()) {
94 continue;
95 }
96 for c in &r.columns {
97 if c.0 == column.clone() {
98 t = Some(c.clone());
99 }
100 }
101 }
102 if let Some(t) = t {
103 let value_type = type_expression(typer, value, flags, t.1.base());
104 if typer.matched_type(&value_type, &t.1).is_none() {
105 typer.err(
106 alloc::format!("Got type {} expected {}", value_type, t.1),
107 value,
108 );
109 } else if let Type::Args(_, args) = &value_type.t {
110 for (idx, arg_type, _) in args.iter() {
111 typer.constrain_arg(*idx, arg_type, &t.1);
112 }
113 }
114 } else {
115 type_expression(typer, value, flags, BaseType::Any);
116 typer
117 .issues
118 .err("Unknown identifier", &key.opt_span().unwrap());
119 }
120 }
121 _ => {
122 type_expression(typer, value, flags, BaseType::Any);
123 typer
124 .issues
125 .err("Unknown identifier", &key.opt_span().unwrap());
126 }
127 }
128 }
129
130 if let Some((where_, _)) = &update.where_ {
131 let t = type_expression(typer, where_, ExpressionFlags::default(), BaseType::Bool);
132 typer.ensure_base(where_, &t, BaseType::Bool);
133 }
134
135 match &update.returning {
136 Some((returning_span, returning_exprs)) => {
137 let columns = type_select_exprs(typer, returning_exprs, true)
138 .into_iter()
139 .map(|(name, type_, span)| SelectTypeColumn { name, type_, span })
140 .collect();
141 Some(SelectType {
142 columns,
143 select_span: returning_span.join_span(returning_exprs),
144 })
145 }
146 None => None,
147 }
148}