1use alloc::{format, vec::Vec};
14use sql_parse::{Expression, Function, Span};
15
16use crate::{
17 type_::{BaseType, FullType},
18 type_expression::{type_expression, ExpressionFlags},
19 typer::Typer,
20 Type,
21};
22
23fn arg_cnt<'a>(
24 typer: &mut Typer<'a, '_>,
25 rng: core::ops::Range<usize>,
26 args: &[Expression<'a>],
27 span: &Span,
28) {
29 if args.len() >= rng.start && args.len() <= rng.end {
30 return;
31 }
32
33 let mut issue = if rng.is_empty() {
34 typer.err(
35 format!("Expected {} arguments got {}", rng.start, args.len()),
36 span,
37 )
38 } else {
39 typer.err(
40 format!(
41 "Expected between {} and {} arguments got {}",
42 rng.start,
43 rng.end,
44 args.len()
45 ),
46 span,
47 )
48 };
49
50 if let Some(args) = args.get(rng.end..) {
51 for (cnt, arg) in args.iter().enumerate() {
52 issue.frag(format!("Argument {}", rng.end + cnt), arg);
53 }
54 }
55}
56
57fn typed_args<'a, 'b, 'c>(
58 typer: &mut Typer<'a, 'b>,
59 args: &'c [Expression<'a>],
60 flags: ExpressionFlags,
61) -> Vec<(&'c Expression<'a>, FullType<'a>)> {
62 let mut typed: Vec<(&'_ Expression, FullType<'a>)> = Vec::new();
63 for arg in args {
64 typed.push((
67 arg,
68 type_expression(typer, arg, flags.without_values(), BaseType::Any),
69 ));
70 }
71 typed
72}
73
74pub(crate) fn type_function<'a, 'b>(
75 typer: &mut Typer<'a, 'b>,
76 func: &Function<'a>,
77 args: &[Expression<'a>],
78 span: &Span,
79 flags: ExpressionFlags,
80) -> FullType<'a> {
81 let mut tf = |return_type: Type<'a>,
82 required_args: &[BaseType],
83 optional_args: &[BaseType]|
84 -> FullType<'a> {
85 let mut not_null = true;
86 let mut arg_iter = args.iter();
87 arg_cnt(
88 typer,
89 required_args.len()..required_args.len() + optional_args.len(),
90 args,
91 span,
92 );
93 for et in required_args {
94 if let Some(arg) = arg_iter.next() {
95 let t = type_expression(typer, arg, flags.without_values(), *et);
96 not_null = not_null && t.not_null;
97 typer.ensure_base(arg, &t, *et);
98 }
99 }
100 for et in optional_args {
101 if let Some(arg) = arg_iter.next() {
102 let t = type_expression(typer, arg, flags.without_values(), *et);
103 not_null = not_null && t.not_null;
104 typer.ensure_base(arg, &t, *et);
105 }
106 }
107 for arg in arg_iter {
108 type_expression(typer, arg, flags.without_values(), BaseType::Any);
109 }
110 FullType::new(return_type, not_null)
111 };
112
113 match func {
114 Function::Rand => tf(Type::F64, &[], &[BaseType::Integer]),
115 Function::Right | Function::Left => tf(
116 BaseType::String.into(),
117 &[BaseType::String, BaseType::Integer],
118 &[],
119 ),
120 Function::SubStr => {
121 arg_cnt(typer, 2..3, args, span);
122
123 let mut return_type = if let Some(arg) = args.first() {
124 let t = type_expression(typer, arg, flags.without_values(), BaseType::Any);
125 if !matches!(t.base(), BaseType::Any | BaseType::String | BaseType::Bytes) {
126 typer.err(format!("Expected type String or Bytes got {}", t), arg);
127 }
128 t
129 } else {
130 FullType::invalid()
131 };
132
133 if let Some(arg) = args.get(1) {
134 let t = type_expression(typer, arg, flags.without_values(), BaseType::Integer);
135 return_type.not_null = return_type.not_null && t.not_null;
136 typer.ensure_base(arg, &t, BaseType::Integer);
137 };
138
139 if let Some(arg) = args.get(2) {
140 let t = type_expression(typer, arg, flags.without_values(), BaseType::Integer);
141 return_type.not_null = return_type.not_null && t.not_null;
142 typer.ensure_base(arg, &t, BaseType::Integer);
143 };
144
145 return_type
146 }
147 Function::FindInSet => tf(
148 BaseType::Integer.into(),
149 &[BaseType::String, BaseType::String],
150 &[],
151 ),
152 Function::SubStringIndex => tf(
153 BaseType::String.into(),
154 &[BaseType::String, BaseType::String, BaseType::Integer],
155 &[],
156 ),
157 Function::ExtractValue => tf(
158 BaseType::String.into(),
159 &[BaseType::String, BaseType::String],
160 &[],
161 ),
162 Function::Replace => tf(
163 BaseType::String.into(),
164 &[BaseType::String, BaseType::String, BaseType::String],
165 &[],
166 ),
167 Function::CharacterLength => tf(BaseType::Integer.into(), &[BaseType::String], &[]),
168 Function::UnixTimestamp => {
169 let mut not_null = true;
170 let typed = typed_args(typer, args, flags);
171 arg_cnt(typer, 0..1, args, span);
172 if let Some((a, t)) = typed.first() {
173 not_null = not_null && t.not_null;
174 typer.ensure_base(*a, t, BaseType::DateTime);
176 }
177 FullType::new(Type::I64, not_null)
178 }
179 Function::IfNull => {
180 let typed = typed_args(typer, args, flags);
181 arg_cnt(typer, 2..2, args, span);
182 let t = if let Some((e, t)) = typed.first() {
183 if t.not_null {
184 typer.warn("Cannot be null", *e);
185 }
186 t.clone()
187 } else {
188 FullType::invalid()
189 };
190 if let Some((e, t2)) = typed.get(1) {
191 typer.ensure_type(*e, t2, &t);
192 t2.clone()
193 } else {
194 t.clone()
195 }
196 }
197 Function::Lead | Function::Lag => {
198 let typed = typed_args(typer, args, flags);
199 arg_cnt(typer, 1..2, args, span);
200 if let Some((a, t)) = typed.get(1) {
201 typer.ensure_base(*a, t, BaseType::Integer);
202 }
203 if let Some((_, t)) = typed.first() {
204 let mut t = t.clone();
205 t.not_null = false;
206 t
207 } else {
208 FullType::invalid()
209 }
210 }
211 Function::JsonExtract => {
212 let typed = typed_args(typer, args, flags);
213 arg_cnt(typer, 2..999, args, span);
214 for (a, t) in &typed {
215 typer.ensure_base(*a, t, BaseType::String);
216 }
217 FullType::new(Type::JSON, false)
218 }
219 Function::JsonValue => {
220 let typed = typed_args(typer, args, flags);
221 arg_cnt(typer, 2..2, args, span);
222 for (a, t) in &typed {
223 typer.ensure_base(*a, t, BaseType::String);
224 }
225 FullType::new(Type::JSON, false)
226 }
227 Function::JsonReplace => {
228 let typed = typed_args(typer, args, flags);
229 arg_cnt(typer, 3..999, args, span);
230 for (i, (a, t)) in typed.iter().enumerate() {
231 if i == 0 || i % 2 == 1 {
232 typer.ensure_base(*a, t, BaseType::String);
233 }
234 }
235 FullType::new(Type::JSON, false)
236 }
237 Function::JsonSet => {
238 let typed = typed_args(typer, args, flags);
239 arg_cnt(typer, 3..999, args, span);
240 for (i, (a, t)) in typed.iter().enumerate() {
241 if i == 0 || i % 2 == 1 {
242 typer.ensure_base(*a, t, BaseType::String);
243 }
244 }
245 FullType::new(Type::JSON, false)
246 }
247 Function::JsonUnquote => {
248 let typed = typed_args(typer, args, flags);
249 arg_cnt(typer, 1..1, args, span);
250 for (a, t) in &typed {
251 typer.ensure_base(*a, t, BaseType::String);
252 }
253 FullType::new(BaseType::String, false)
254 }
255 Function::JsonQuery => {
256 let typed = typed_args(typer, args, flags);
257 arg_cnt(typer, 2..2, args, span);
258 for (a, t) in &typed {
259 typer.ensure_base(*a, t, BaseType::String);
260 }
261 FullType::new(Type::JSON, false)
262 }
263 Function::JsonRemove => {
264 let typed = typed_args(typer, args, flags);
265 arg_cnt(typer, 2..999, args, span);
266 for (a, t) in &typed {
267 typer.ensure_base(*a, t, BaseType::String);
268 }
269 FullType::new(Type::JSON, false)
270 }
271 Function::JsonContains => {
272 let typed = typed_args(typer, args, flags);
273 arg_cnt(typer, 2..3, args, span);
274 for (a, t) in &typed {
275 typer.ensure_base(*a, t, BaseType::String);
276 }
277 if let (Some(t0), Some(t1), t2) = (typed.first(), typed.get(1), typed.get(2)) {
278 let not_null =
279 t0.1.not_null && t1.1.not_null && t2.map(|t| t.1.not_null).unwrap_or(true);
280 FullType::new(Type::Base(BaseType::Bool), not_null)
281 } else {
282 FullType::invalid()
283 }
284 }
285 Function::JsonContainsPath => {
286 let typed = typed_args(typer, args, flags);
287 arg_cnt(typer, 3..999, args, span);
288 for (a, t) in &typed {
289 typer.ensure_base(*a, t, BaseType::String);
290 }
291 FullType::new(Type::JSON, false)
292 }
293 Function::JsonOverlaps => {
294 let typed = typed_args(typer, args, flags);
295 arg_cnt(typer, 2..2, args, span);
296 for (a, t) in &typed {
297 typer.ensure_base(*a, t, BaseType::String);
298 }
299 if let (Some(t0), Some(t1)) = (typed.first(), typed.get(1)) {
300 let not_null = t0.1.not_null && t1.1.not_null;
301 FullType::new(Type::Base(BaseType::Bool), not_null)
302 } else {
303 FullType::invalid()
304 }
305 }
306 Function::Min | Function::Max | Function::Sum => {
307 let typed = typed_args(typer, args, flags);
308 arg_cnt(typer, 1..1, args, span);
309 if let Some((_, t2)) = typed.first() {
310 let mut v = t2.clone();
313 v.not_null = false;
314 v
315 } else {
316 FullType::invalid()
317 }
318 }
319 Function::Now => tf(BaseType::DateTime.into(), &[], &[BaseType::Integer]),
320 Function::CurDate => tf(BaseType::Date.into(), &[], &[]),
321 Function::CurrentTimestamp => tf(BaseType::TimeStamp.into(), &[], &[BaseType::Integer]),
322 Function::Concat => {
323 let typed = typed_args(typer, args, flags);
324 let mut not_null = true;
325 for (a, t) in &typed {
326 typer.ensure_base(*a, t, BaseType::Any);
327 not_null = not_null && t.not_null;
328 }
329 FullType::new(BaseType::String, not_null)
330 }
331 Function::Least | Function::Greatest => {
332 let typed = typed_args(typer, args, flags);
333 arg_cnt(typer, 1..9999, args, span);
334 if let Some((a, at)) = typed.first() {
335 let mut not_null = true;
336 let mut t = at.t.clone();
337 for (b, bt) in &typed[1..] {
338 not_null = not_null && bt.not_null;
339 if bt.t == t {
340 continue;
341 };
342 if let Some(tt) = typer.matched_type(&bt.t, &t) {
343 t = tt;
344 } else {
345 typer
346 .err("None matching input types", span)
347 .frag(format!("Type {}", at.t), *a)
348 .frag(format!("Type {}", bt.t), *b);
349 }
350 }
351 FullType::new(t, true);
352 }
353 FullType::new(BaseType::Any, true)
354 }
355 Function::If => {
356 let typed = typed_args(typer, args, flags);
357 arg_cnt(typer, 3..3, args, span);
358 let mut not_null = true;
359 if let Some((e, t)) = typed.first() {
360 not_null = not_null && t.not_null;
361 typer.ensure_base(*e, t, BaseType::Bool);
362 }
363 let mut ans = FullType::invalid();
364 if let Some((e1, t1)) = typed.get(1) {
365 not_null = not_null && t1.not_null;
366 if let Some((e2, t2)) = typed.get(2) {
367 not_null = not_null && t2.not_null;
368 if let Some(t) = typer.matched_type(t1, t2) {
369 ans = FullType::new(t, not_null);
370 } else {
371 typer
372 .err("Incompatible types", span)
373 .frag(format!("Of type {}", t1.t), *e1)
374 .frag(format!("Of type {}", t2.t), *e2);
375 }
376 }
377 }
378 ans
379 }
380 Function::FromUnixTime => {
381 let typed = typed_args(typer, args, flags);
382 arg_cnt(typer, 1..2, args, span);
383 let mut not_null = true;
384 if let Some((e, t)) = typed.first() {
385 not_null = not_null && t.not_null;
386 typer.ensure_base(*e, t, BaseType::Float);
388 }
389 if let Some((e, t)) = typed.get(1) {
390 not_null = not_null && t.not_null;
391 typer.ensure_base(*e, t, BaseType::String);
392 FullType::new(BaseType::String, not_null)
393 } else {
394 FullType::new(BaseType::DateTime, not_null)
395 }
396 }
397 Function::DateFormat => tf(
398 BaseType::String.into(),
399 &[BaseType::DateTime, BaseType::String],
400 &[BaseType::String],
401 ),
402 Function::Value => {
403 let typed = typed_args(typer, args, flags);
404 if !flags.in_on_duplicate_key_update {
405 typer.err("VALUE is only allowed within ON DUPLICATE KEY UPDATE", span);
406 }
407 arg_cnt(typer, 1..1, args, span);
408 if let Some((_, t)) = typed.first() {
409 t.clone()
410 } else {
411 FullType::invalid()
412 }
413 }
414 Function::Length => {
415 let typed = typed_args(typer, args, flags);
416 arg_cnt(typer, 1..1, args, span);
417 let mut not_null = true;
418 for (_, t) in &typed {
419 not_null = not_null && t.not_null;
420 if typer
421 .matched_type(t, &FullType::new(BaseType::String, false))
422 .is_none()
423 && typer
424 .matched_type(t, &FullType::new(BaseType::Bytes, false))
425 .is_none()
426 {
427 typer.err(format!("Expected type Bytes or String got {}", t), span);
428 }
429 }
430 FullType::new(Type::I64, not_null)
431 }
432 Function::Strftime => {
433 let typed = typed_args(typer, args, flags);
434 arg_cnt(typer, 2..2, args, span);
435 let mut not_null = true;
436 if let Some((e, t)) = typed.first() {
437 not_null = not_null && t.not_null;
438 typer.ensure_base(*e, t, BaseType::String);
439 }
440 if let Some((e, t)) = typed.last() {
441 not_null = not_null && t.not_null;
442 typer.ensure_base(*e, t, BaseType::DateTime);
443 }
444 FullType::new(BaseType::String, not_null)
445 }
446 Function::Datetime => {
447 let typed = typed_args(typer, args, flags);
448 arg_cnt(typer, 1..1, args, span);
449 let mut not_null = true;
450 if let Some((e, t)) = typed.first() {
451 not_null = not_null && t.not_null;
452 typer.ensure_base(*e, t, BaseType::String);
453 }
454 FullType::new(BaseType::DateTime, not_null)
455 }
456 _ => {
457 typer.err("Typing for function not implemented", span);
458 FullType::invalid()
459 }
460 }
461}