sql_type/
type_function.rs

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use alloc::{format, vec::Vec};
14use sql_parse::{Expression, Function, Span};
15
16use crate::{
17    type_::{BaseType, FullType},
18    type_expression::{type_expression, ExpressionFlags},
19    typer::Typer,
20    Type,
21};
22
23fn arg_cnt<'a>(
24    typer: &mut Typer<'a, '_>,
25    rng: core::ops::Range<usize>,
26    args: &[Expression<'a>],
27    span: &Span,
28) {
29    if args.len() >= rng.start && args.len() <= rng.end {
30        return;
31    }
32
33    let mut issue = if rng.is_empty() {
34        typer.err(
35            format!("Expected {} arguments got {}", rng.start, args.len()),
36            span,
37        )
38    } else {
39        typer.err(
40            format!(
41                "Expected between {} and {} arguments got {}",
42                rng.start,
43                rng.end,
44                args.len()
45            ),
46            span,
47        )
48    };
49
50    if let Some(args) = args.get(rng.end..) {
51        for (cnt, arg) in args.iter().enumerate() {
52            issue.frag(format!("Argument {}", rng.end + cnt), arg);
53        }
54    }
55}
56
57fn typed_args<'a, 'b, 'c>(
58    typer: &mut Typer<'a, 'b>,
59    args: &'c [Expression<'a>],
60    flags: ExpressionFlags,
61) -> Vec<(&'c Expression<'a>, FullType<'a>)> {
62    let mut typed: Vec<(&'_ Expression, FullType<'a>)> = Vec::new();
63    for arg in args {
64        // TODO we need not always disable the not null flag here
65        // TODO we should not supply base type any here, this function needs to die
66        typed.push((
67            arg,
68            type_expression(typer, arg, flags.without_values(), BaseType::Any),
69        ));
70    }
71    typed
72}
73
74pub(crate) fn type_function<'a, 'b>(
75    typer: &mut Typer<'a, 'b>,
76    func: &Function<'a>,
77    args: &[Expression<'a>],
78    span: &Span,
79    flags: ExpressionFlags,
80) -> FullType<'a> {
81    let mut tf = |return_type: Type<'a>,
82                  required_args: &[BaseType],
83                  optional_args: &[BaseType]|
84     -> FullType<'a> {
85        let mut not_null = true;
86        let mut arg_iter = args.iter();
87        arg_cnt(
88            typer,
89            required_args.len()..required_args.len() + optional_args.len(),
90            args,
91            span,
92        );
93        for et in required_args {
94            if let Some(arg) = arg_iter.next() {
95                let t = type_expression(typer, arg, flags.without_values(), *et);
96                not_null = not_null && t.not_null;
97                typer.ensure_base(arg, &t, *et);
98            }
99        }
100        for et in optional_args {
101            if let Some(arg) = arg_iter.next() {
102                let t = type_expression(typer, arg, flags.without_values(), *et);
103                not_null = not_null && t.not_null;
104                typer.ensure_base(arg, &t, *et);
105            }
106        }
107        for arg in arg_iter {
108            type_expression(typer, arg, flags.without_values(), BaseType::Any);
109        }
110        FullType::new(return_type, not_null)
111    };
112
113    match func {
114        Function::Rand => tf(Type::F64, &[], &[BaseType::Integer]),
115        Function::Right | Function::Left => tf(
116            BaseType::String.into(),
117            &[BaseType::String, BaseType::Integer],
118            &[],
119        ),
120        Function::SubStr => {
121            arg_cnt(typer, 2..3, args, span);
122
123            let mut return_type = if let Some(arg) = args.first() {
124                let t = type_expression(typer, arg, flags.without_values(), BaseType::Any);
125                if !matches!(t.base(), BaseType::Any | BaseType::String | BaseType::Bytes) {
126                    typer.err(format!("Expected type String or Bytes got {}", t), arg);
127                }
128                t
129            } else {
130                FullType::invalid()
131            };
132
133            if let Some(arg) = args.get(1) {
134                let t = type_expression(typer, arg, flags.without_values(), BaseType::Integer);
135                return_type.not_null = return_type.not_null && t.not_null;
136                typer.ensure_base(arg, &t, BaseType::Integer);
137            };
138
139            if let Some(arg) = args.get(2) {
140                let t = type_expression(typer, arg, flags.without_values(), BaseType::Integer);
141                return_type.not_null = return_type.not_null && t.not_null;
142                typer.ensure_base(arg, &t, BaseType::Integer);
143            };
144
145            return_type
146        }
147        Function::FindInSet => tf(
148            BaseType::Integer.into(),
149            &[BaseType::String, BaseType::String],
150            &[],
151        ),
152        Function::SubStringIndex => tf(
153            BaseType::String.into(),
154            &[BaseType::String, BaseType::String, BaseType::Integer],
155            &[],
156        ),
157        Function::ExtractValue => tf(
158            BaseType::String.into(),
159            &[BaseType::String, BaseType::String],
160            &[],
161        ),
162        Function::Replace => tf(
163            BaseType::String.into(),
164            &[BaseType::String, BaseType::String, BaseType::String],
165            &[],
166        ),
167        Function::CharacterLength => tf(BaseType::Integer.into(), &[BaseType::String], &[]),
168        Function::UnixTimestamp => {
169            let mut not_null = true;
170            let typed = typed_args(typer, args, flags);
171            arg_cnt(typer, 0..1, args, span);
172            if let Some((a, t)) = typed.first() {
173                not_null = not_null && t.not_null;
174                // TODO the argument can be both a DATE, a DATE_TIME or a TIMESTAMP
175                typer.ensure_base(*a, t, BaseType::DateTime);
176            }
177            FullType::new(Type::I64, not_null)
178        }
179        Function::IfNull => {
180            let typed = typed_args(typer, args, flags);
181            arg_cnt(typer, 2..2, args, span);
182            let t = if let Some((e, t)) = typed.first() {
183                if t.not_null {
184                    typer.warn("Cannot be null", *e);
185                }
186                t.clone()
187            } else {
188                FullType::invalid()
189            };
190            if let Some((e, t2)) = typed.get(1) {
191                typer.ensure_type(*e, t2, &t);
192                t2.clone()
193            } else {
194                t.clone()
195            }
196        }
197        Function::Lead | Function::Lag => {
198            let typed = typed_args(typer, args, flags);
199            arg_cnt(typer, 1..2, args, span);
200            if let Some((a, t)) = typed.get(1) {
201                typer.ensure_base(*a, t, BaseType::Integer);
202            }
203            if let Some((_, t)) = typed.first() {
204                let mut t = t.clone();
205                t.not_null = false;
206                t
207            } else {
208                FullType::invalid()
209            }
210        }
211        Function::JsonExtract => {
212            let typed = typed_args(typer, args, flags);
213            arg_cnt(typer, 2..999, args, span);
214            for (a, t) in &typed {
215                typer.ensure_base(*a, t, BaseType::String);
216            }
217            FullType::new(Type::JSON, false)
218        }
219        Function::JsonValue => {
220            let typed = typed_args(typer, args, flags);
221            arg_cnt(typer, 2..2, args, span);
222            for (a, t) in &typed {
223                typer.ensure_base(*a, t, BaseType::String);
224            }
225            FullType::new(Type::JSON, false)
226        }
227        Function::JsonReplace => {
228            let typed = typed_args(typer, args, flags);
229            arg_cnt(typer, 3..999, args, span);
230            for (i, (a, t)) in typed.iter().enumerate() {
231                if i == 0 || i % 2 == 1 {
232                    typer.ensure_base(*a, t, BaseType::String);
233                }
234            }
235            FullType::new(Type::JSON, false)
236        }
237        Function::JsonSet => {
238            let typed = typed_args(typer, args, flags);
239            arg_cnt(typer, 3..999, args, span);
240            for (i, (a, t)) in typed.iter().enumerate() {
241                if i == 0 || i % 2 == 1 {
242                    typer.ensure_base(*a, t, BaseType::String);
243                }
244            }
245            FullType::new(Type::JSON, false)
246        }
247        Function::JsonUnquote => {
248            let typed = typed_args(typer, args, flags);
249            arg_cnt(typer, 1..1, args, span);
250            for (a, t) in &typed {
251                typer.ensure_base(*a, t, BaseType::String);
252            }
253            FullType::new(BaseType::String, false)
254        }
255        Function::JsonQuery => {
256            let typed = typed_args(typer, args, flags);
257            arg_cnt(typer, 2..2, args, span);
258            for (a, t) in &typed {
259                typer.ensure_base(*a, t, BaseType::String);
260            }
261            FullType::new(Type::JSON, false)
262        }
263        Function::JsonRemove => {
264            let typed = typed_args(typer, args, flags);
265            arg_cnt(typer, 2..999, args, span);
266            for (a, t) in &typed {
267                typer.ensure_base(*a, t, BaseType::String);
268            }
269            FullType::new(Type::JSON, false)
270        }
271        Function::JsonContains => {
272            let typed = typed_args(typer, args, flags);
273            arg_cnt(typer, 2..3, args, span);
274            for (a, t) in &typed {
275                typer.ensure_base(*a, t, BaseType::String);
276            }
277            if let (Some(t0), Some(t1), t2) = (typed.first(), typed.get(1), typed.get(2)) {
278                let not_null =
279                    t0.1.not_null && t1.1.not_null && t2.map(|t| t.1.not_null).unwrap_or(true);
280                FullType::new(Type::Base(BaseType::Bool), not_null)
281            } else {
282                FullType::invalid()
283            }
284        }
285        Function::JsonContainsPath => {
286            let typed = typed_args(typer, args, flags);
287            arg_cnt(typer, 3..999, args, span);
288            for (a, t) in &typed {
289                typer.ensure_base(*a, t, BaseType::String);
290            }
291            FullType::new(Type::JSON, false)
292        }
293        Function::JsonOverlaps => {
294            let typed = typed_args(typer, args, flags);
295            arg_cnt(typer, 2..2, args, span);
296            for (a, t) in &typed {
297                typer.ensure_base(*a, t, BaseType::String);
298            }
299            if let (Some(t0), Some(t1)) = (typed.first(), typed.get(1)) {
300                let not_null = t0.1.not_null && t1.1.not_null;
301                FullType::new(Type::Base(BaseType::Bool), not_null)
302            } else {
303                FullType::invalid()
304            }
305        }
306        Function::Min | Function::Max | Function::Sum => {
307            let typed = typed_args(typer, args, flags);
308            arg_cnt(typer, 1..1, args, span);
309            if let Some((_, t2)) = typed.first() {
310                // TODO check that the type can be mined or maxed
311                // Result can be null if there are no rows to aggregate over
312                let mut v = t2.clone();
313                v.not_null = false;
314                v
315            } else {
316                FullType::invalid()
317            }
318        }
319        Function::Now => tf(BaseType::DateTime.into(), &[], &[BaseType::Integer]),
320        Function::CurDate => tf(BaseType::Date.into(), &[], &[]),
321        Function::CurrentTimestamp => tf(BaseType::TimeStamp.into(), &[], &[BaseType::Integer]),
322        Function::Concat => {
323            let typed = typed_args(typer, args, flags);
324            let mut not_null = true;
325            for (a, t) in &typed {
326                typer.ensure_base(*a, t, BaseType::Any);
327                not_null = not_null && t.not_null;
328            }
329            FullType::new(BaseType::String, not_null)
330        }
331        Function::Least | Function::Greatest => {
332            let typed = typed_args(typer, args, flags);
333            arg_cnt(typer, 1..9999, args, span);
334            if let Some((a, at)) = typed.first() {
335                let mut not_null = true;
336                let mut t = at.t.clone();
337                for (b, bt) in &typed[1..] {
338                    not_null = not_null && bt.not_null;
339                    if bt.t == t {
340                        continue;
341                    };
342                    if let Some(tt) = typer.matched_type(&bt.t, &t) {
343                        t = tt;
344                    } else {
345                        typer
346                            .err("None matching input types", span)
347                            .frag(format!("Type {}", at.t), *a)
348                            .frag(format!("Type {}", bt.t), *b);
349                    }
350                }
351                FullType::new(t, true);
352            }
353            FullType::new(BaseType::Any, true)
354        }
355        Function::If => {
356            let typed = typed_args(typer, args, flags);
357            arg_cnt(typer, 3..3, args, span);
358            let mut not_null = true;
359            if let Some((e, t)) = typed.first() {
360                not_null = not_null && t.not_null;
361                typer.ensure_base(*e, t, BaseType::Bool);
362            }
363            let mut ans = FullType::invalid();
364            if let Some((e1, t1)) = typed.get(1) {
365                not_null = not_null && t1.not_null;
366                if let Some((e2, t2)) = typed.get(2) {
367                    not_null = not_null && t2.not_null;
368                    if let Some(t) = typer.matched_type(t1, t2) {
369                        ans = FullType::new(t, not_null);
370                    } else {
371                        typer
372                            .err("Incompatible types", span)
373                            .frag(format!("Of type {}", t1.t), *e1)
374                            .frag(format!("Of type {}", t2.t), *e2);
375                    }
376                }
377            }
378            ans
379        }
380        Function::FromUnixTime => {
381            let typed = typed_args(typer, args, flags);
382            arg_cnt(typer, 1..2, args, span);
383            let mut not_null = true;
384            if let Some((e, t)) = typed.first() {
385                not_null = not_null && t.not_null;
386                // TODO float og int
387                typer.ensure_base(*e, t, BaseType::Float);
388            }
389            if let Some((e, t)) = typed.get(1) {
390                not_null = not_null && t.not_null;
391                typer.ensure_base(*e, t, BaseType::String);
392                FullType::new(BaseType::String, not_null)
393            } else {
394                FullType::new(BaseType::DateTime, not_null)
395            }
396        }
397        Function::DateFormat => tf(
398            BaseType::String.into(),
399            &[BaseType::DateTime, BaseType::String],
400            &[BaseType::String],
401        ),
402        Function::Value => {
403            let typed = typed_args(typer, args, flags);
404            if !flags.in_on_duplicate_key_update {
405                typer.err("VALUE is only allowed within ON DUPLICATE KEY UPDATE", span);
406            }
407            arg_cnt(typer, 1..1, args, span);
408            if let Some((_, t)) = typed.first() {
409                t.clone()
410            } else {
411                FullType::invalid()
412            }
413        }
414        Function::Length => {
415            let typed = typed_args(typer, args, flags);
416            arg_cnt(typer, 1..1, args, span);
417            let mut not_null = true;
418            for (_, t) in &typed {
419                not_null = not_null && t.not_null;
420                if typer
421                    .matched_type(t, &FullType::new(BaseType::String, false))
422                    .is_none()
423                    && typer
424                        .matched_type(t, &FullType::new(BaseType::Bytes, false))
425                        .is_none()
426                {
427                    typer.err(format!("Expected type Bytes or String got {}", t), span);
428                }
429            }
430            FullType::new(Type::I64, not_null)
431        }
432        Function::Strftime => {
433            let typed = typed_args(typer, args, flags);
434            arg_cnt(typer, 2..2, args, span);
435            let mut not_null = true;
436            if let Some((e, t)) = typed.first() {
437                not_null = not_null && t.not_null;
438                typer.ensure_base(*e, t, BaseType::String);
439            }
440            if let Some((e, t)) = typed.last() {
441                not_null = not_null && t.not_null;
442                typer.ensure_base(*e, t, BaseType::DateTime);
443            }
444            FullType::new(BaseType::String, not_null)
445        }
446        Function::Datetime => {
447            let typed = typed_args(typer, args, flags);
448            arg_cnt(typer, 1..1, args, span);
449            let mut not_null = true;
450            if let Some((e, t)) = typed.first() {
451                not_null = not_null && t.not_null;
452                typer.ensure_base(*e, t, BaseType::String);
453            }
454            FullType::new(BaseType::DateTime, not_null)
455        }
456        _ => {
457            typer.err("Typing for function not implemented", span);
458            FullType::invalid()
459        }
460    }
461}