1mod analyze;
2mod arith_expr;
3mod array;
4mod boolean;
5mod case_expr;
6mod cast_expr;
7mod coalesce;
8mod collate;
9mod column_ref;
10mod const_expr;
11mod func_call;
12mod indirect;
13mod min_max;
14mod named_arg;
15mod param_ref;
16mod sql_value;
17mod sub_link;
18
19use sql_fun_core::IVec;
20
21pub use self::{
22 analyze::analyze_scaler_expr,
23 arith_expr::ArithExpr,
24 array::ArrayExpr,
25 boolean::{BooleanExpr, NullTestExpr},
26 case_expr::CaseExpr,
27 cast_expr::{ImplicitCastExpr, TypeCastExpr},
28 coalesce::CoalesceExpr,
29 collate::CollateExpr,
30 column_ref::{ColumnReferenceExpr, CteColumnRef, SubQueryColumnRef, TableColumnRef},
31 const_expr::ScalarConstExpr,
32 func_call::FuncCallExpr,
33 indirect::IndirectionExpr,
34 min_max::MinMaxExpr,
35 named_arg::NamedArgExpr,
36 param_ref::ParamRef,
37 sql_value::SqlValueExpr,
38 sub_link::SubLinkExpr,
39};
40
41use crate::{
42 sem::{
43 AnalysisError, AnalysisProblem, FromClause, ParseContext, PgBuiltInType, TypeReference,
44 WithClause, create_table::ColumnDefinition,
45 },
46 syn::ScanToken,
47};
48
49trait AnalyzeScalarExpr<TParseContext, TNode>
50where
51 TParseContext: ParseContext,
52{
53 fn analyze_scalar_expr(
54 context: TParseContext,
55 with_clause: &WithClause,
56 from_clause: &FromClause,
57 syn: TNode,
58 tokens: &IVec<ScanToken>,
59 ) -> Result<(SemScalarExpr, TParseContext), AnalysisError>;
60}
61
62trait SemScalarExprNode {
63 fn get_type(&self) -> Option<TypeReference>;
64 fn is_not_null(&self) -> Option<bool>;
65}
66
/// Semantic (analyzed) form of a scalar SQL expression.
///
/// Each variant wraps the node type for one kind of expression. The inherent
/// methods `get_type` and `is_not_null` mostly delegate to the wrapped node.
///
/// NOTE(review): variant order is significant. Reordering changes the derived
/// `Hash` output and the variant index used by index-based serde formats, so
/// append new variants instead of inserting them.
#[derive(Debug, Clone, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SemScalarExpr {
    /// Wraps an [`ArrayExpr`].
    Array(ArrayExpr),
    /// Wraps a [`ScalarConstExpr`] (literal constant, including `NULL`).
    Const(ScalarConstExpr),
    /// Wraps a [`ColumnReferenceExpr`]. This is the only variant that can carry a column definition.
    ColumnRef(ColumnReferenceExpr),
    /// Wraps a [`FuncCallExpr`].
    FuncCall(FuncCallExpr),
    /// Wraps an [`ArithExpr`].
    Arith(ArithExpr),
    /// Wraps a [`TypeCastExpr`] (a cast written in the source).
    TypeCast(TypeCastExpr),
    /// Wraps a [`CaseExpr`].
    Case(CaseExpr),
    /// Wraps a [`BooleanExpr`]. Always typed as built-in `bool`.
    Boolean(BooleanExpr),
    /// Wraps a [`NullTestExpr`]. Always typed as built-in `bool`.
    NullTest(NullTestExpr),
    /// Wraps a [`SubLinkExpr`].
    SubLink(SubLinkExpr),
    /// Wraps a [`CoalesceExpr`].
    Coalesce(CoalesceExpr),
    /// Wraps a [`MinMaxExpr`].
    MinMax(MinMaxExpr),
    /// Wraps a [`CollateExpr`].
    Collate(CollateExpr),
    /// Wraps a [`SqlValueExpr`].
    SqlValue(SqlValueExpr),
    /// Wraps a [`ParamRef`].
    Param(ParamRef),
    /// Wraps a [`NamedArgExpr`].
    NamedArg(NamedArgExpr),
    /// Wraps an [`IndirectionExpr`].
    Indirection(IndirectionExpr),
    /// Wraps an [`ImplicitCastExpr`] inserted by analysis (see `wrap_implicit_cast`).
    ImplicitCast(ImplicitCastExpr),

    /// Fallback for a node with no dedicated variant. The string presumably
    /// describes the unsupported node (TODO confirm). Type and nullability are
    /// reported as unknown (`None`).
    Unexpected(String),
}
117
118impl SemScalarExpr {
119 #[cfg(test)]
121 pub fn new_const(value: ScalarConstExpr) -> Self {
122 Self::Const(value)
123 }
124
125 fn builtin_boolean_type() -> Option<TypeReference> {
126 Some(TypeReference::concrete_type_ref(
127 PgBuiltInType::bool().full_name(),
128 false,
129 ))
130 }
131
132 #[must_use]
134 pub fn get_type(&self) -> Option<TypeReference> {
135 match self {
136 Self::Const(scalar_const_expr) => scalar_const_expr.get_type(),
137 Self::FuncCall(func_call_expr) => func_call_expr.get_type(),
138 Self::Arith(arith_expr) => arith_expr.get_type(),
139 Self::TypeCast(type_cast_expr) => type_cast_expr.get_type(),
140 Self::Case(case_expr) => case_expr.get_type(),
141 Self::Boolean(_boolean_expr) => Self::builtin_boolean_type(),
142 Self::NullTest(_null_test_expr) => Self::builtin_boolean_type(),
143 Self::SubLink(sub_link_expr) => sub_link_expr.get_type(),
144 Self::Coalesce(coalesce_expr) => coalesce_expr.get_type(),
145 Self::MinMax(min_max_expr) => min_max_expr.get_type(),
146 Self::Collate(collate_expr) => collate_expr.get_type(),
147 Self::SqlValue(sql_value_expr) => sql_value_expr.get_type(),
148 Self::Param(param_ref) => param_ref.get_type(),
149 Self::NamedArg(named_arg_expr) => named_arg_expr.get_type(),
150 Self::Indirection(ind) => ind.get_type(),
151 Self::ImplicitCast(ice) => ice.get_type(),
152 Self::Array(arr) => arr.get_type(),
153 Self::Unexpected(_node) => None,
154 Self::ColumnRef(cr) => cr.get_type(),
155 }
156 }
157
158 #[must_use]
160 pub fn is_not_null(&self) -> Option<bool> {
161 match self {
162 Self::Const(scalar_const_expr) => scalar_const_expr.is_not_null(),
163 Self::FuncCall(func_call_expr) => func_call_expr.is_not_null(),
164 Self::Arith(arith_expr) => arith_expr.is_not_null(),
165 Self::TypeCast(type_cast_expr) => type_cast_expr.is_not_null(),
166 Self::Case(case_expr) => case_expr.is_not_null(),
167 Self::Boolean(boolean_expr) => boolean_expr.is_not_null(),
168 Self::NullTest(null_test_expr) => null_test_expr.is_not_null(),
169 Self::SubLink(sub_link_expr) => sub_link_expr.is_not_null(),
170 Self::Coalesce(coalesce_expr) => coalesce_expr.is_not_null(),
171 Self::MinMax(min_max_expr) => min_max_expr.is_not_null(),
172 Self::Collate(collate_expr) => collate_expr.is_not_null(),
173 Self::SqlValue(sql_value_expr) => sql_value_expr.is_not_null(),
174 Self::Param(param_ref) => param_ref.is_not_null(),
175 Self::NamedArg(named_arg_expr) => named_arg_expr.is_not_null(),
176 Self::Indirection(ind) => ind.is_not_null(),
177 Self::ImplicitCast(ice) => ice.is_not_null(),
178 Self::Array(arr) => arr.is_not_null(),
179 Self::Unexpected(_node) => None,
180 Self::ColumnRef(cr) => cr.is_not_null(),
181 }
182 }
183}
184
#[cfg(test)]
mod test_get_type_and_is_not_null {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::PgBuiltInType;

    /// An integer literal is `int4` and non-null. A `NULL` literal is untyped and nullable.
    #[test]
    fn get_type_and_is_not_null_for_const() {
        let ten = SemScalarExpr::Const(ScalarConstExpr::new_integer(10));
        assert_eq!(ten.get_type(), Some(PgBuiltInType::int4()));
        assert_eq!(ten.is_not_null(), Some(true));

        let null = SemScalarExpr::Const(ScalarConstExpr::Null);
        assert_eq!(null.get_type(), None);
        assert_eq!(null.is_not_null(), Some(false));
    }
}
201
202impl SemScalarExpr {
203 #[must_use]
205 pub fn get_column_def(&self) -> Option<ColumnDefinition> {
206 match self {
207 Self::ColumnRef(c) => c.get_column_def().clone(),
208 Self::TypeCast(_) | SemScalarExpr::FuncCall(_) => None,
209 _ => todo!("not implemented {self:?}"),
210 }
211 }
212
213 #[must_use]
215 pub fn get_column_name(&self) -> String {
216 match self {
217 Self::Const(_scalar_const_expr) => String::new(),
218 Self::ColumnRef(cr) => cr.get_column_name().to_string(),
219 Self::FuncCall(_func_call_expr) => String::new(),
220 Self::Arith(_arith_expr) => todo!(),
221 Self::TypeCast(_type_cast_expr) => String::new(),
222 Self::Case(_case_expr) => todo!(),
223 Self::Boolean(_boolean_expr) => String::new(),
224 Self::NullTest(_null_test_expr) => todo!(),
225 Self::SubLink(_sub_link_expr) => todo!(),
226 Self::Coalesce(_coalesce_expr) => todo!(),
227 Self::MinMax(_min_max_expr) => todo!(),
228 Self::Collate(_collate_expr) => todo!(),
229 Self::SqlValue(_sql_value_expr) => todo!(),
230 Self::Param(_param_ref) => todo!(),
231 Self::NamedArg(_named_arg_expr) => todo!(),
232 Self::Indirection(_indirection_expr) => todo!(),
233 Self::ImplicitCast(_implicit_cast_expr) => todo!(),
234 Self::Array(_arr) => todo!(),
235 Self::Unexpected(_node) => todo!(),
236 }
237 }
238}
239
240#[cfg(test)]
241mod test_get_column_def_and_name {
242 use super::{ColumnReferenceExpr, SemScalarExpr, TableColumnRef, TypeReference};
243 use crate::sem::{
244 FullName, PgBuiltInType,
245 create_table::{ColumnDefinition, ColumnName, TableName},
246 data_source::AliasName,
247 };
248
249 fn make_table_column_expr(
250 column_name: &str,
251 column_type: &TypeReference,
252 is_not_null: Option<bool>,
253 ) -> SemScalarExpr {
254 let alias = AliasName::from("t");
255 let column = ColumnName::from(column_name);
256 let table_name = TableName::from(FullName::with_schema("public", "tbl"));
257 let col_def = ColumnDefinition::new(&Some(column.clone()), Some(column_type), is_not_null);
258 let table_column = TableColumnRef::new(&alias, &column, &table_name, Some(&col_def), false);
259 SemScalarExpr::ColumnRef(ColumnReferenceExpr::TableColumn(table_column))
260 }
261
262 #[test]
263 fn get_column_def_and_name_from_column_ref() {
264 let col_type = PgBuiltInType::int4();
265 let expr = make_table_column_expr("col", &col_type, Some(true));
266
267 let col_def = expr.get_column_def().expect("column definition");
268 assert_eq!(col_def.get_type(), Some(col_type));
269 assert_eq!(expr.get_column_name(), "col");
270 }
271}
272
273impl SemScalarExpr {
274 #[must_use]
276 pub fn wrap_implicit_cast(taget_type: &TypeReference, value_expr: &SemScalarExpr) -> Self {
277 Self::ImplicitCast(ImplicitCastExpr::new(value_expr, taget_type))
278 }
279}
280
#[cfg(test)]
mod test_wrap_implicit_cast {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::PgBuiltInType;

    /// Wrapping an integer constant for `int8` yields an implicit-cast node typed `int8`.
    #[test]
    fn wrap_implicit_cast_sets_type() {
        let int8 = PgBuiltInType::int8();
        let one = SemScalarExpr::Const(ScalarConstExpr::new_integer(1));

        let wrapped = SemScalarExpr::wrap_implicit_cast(&int8, &one);
        let SemScalarExpr::ImplicitCast(_) = &wrapped else {
            panic!("expected implicit cast expression");
        };
        assert_eq!(wrapped.get_type(), Some(int8));
    }
}
300
301impl SemScalarExpr {
302 pub fn implicit_cast_if_require<TParseContext>(
304 context: &mut TParseContext,
305 result_type: &TypeReference,
306 expr: &mut SemScalarExpr,
307 ) -> Result<(), AnalysisError>
308 where
309 TParseContext: ParseContext,
310 {
311 let Some(ty) = expr.get_type() else {
312 return Ok(());
313 };
314 if &ty == result_type {
315 return Ok(());
316 }
317
318 let Some(cast) = context.get_implicit_cast(&ty, result_type) else {
319 context.report_problem(AnalysisProblem::implicit_cast_not_found(&ty, result_type))?;
320 return Ok(());
321 };
322 if !cast.is_no_coversion() {
323 *expr = SemScalarExpr::wrap_implicit_cast(result_type, expr);
324 }
325 Ok(())
326 }
327}
328
#[cfg(test)]
mod test_implicit_cast_if_require {
    use super::{ScalarConstExpr, SemScalarExpr};
    use crate::sem::{CastContext, CastDefinition, PgBuiltInType};
    use crate::test_helpers::TestParseContext;

    /// The integer constant `3` (typed `int4`) used as input by every case.
    fn int_three() -> SemScalarExpr {
        SemScalarExpr::Const(ScalarConstExpr::new_integer(3))
    }

    #[test]
    fn implicit_cast_if_require_no_change_when_same_type() {
        let int4 = PgBuiltInType::int4();
        let mut ctx = TestParseContext::default();
        let mut value = int_three();

        SemScalarExpr::implicit_cast_if_require(&mut ctx, &int4, &mut value)
            .expect("implicit cast check");

        assert!(matches!(value, SemScalarExpr::Const(_)));
        assert_eq!(ctx.reported_problem_count(), 0);
    }

    #[test]
    fn implicit_cast_if_require_wraps_on_available_cast() {
        let (from, to) = (PgBuiltInType::int4(), PgBuiltInType::int8());
        let mut ctx = TestParseContext::default();
        ctx.set_get_implicit_cast_result(
            &from,
            &to,
            Some(CastDefinition::new(CastContext::Implicit)),
        );
        let mut value = int_three();

        SemScalarExpr::implicit_cast_if_require(&mut ctx, &to, &mut value)
            .expect("implicit cast check");

        assert!(matches!(value, SemScalarExpr::ImplicitCast(_)));
        assert_eq!(value.get_type(), Some(to));
        assert_eq!(ctx.reported_problem_count(), 0);
    }

    #[test]
    fn implicit_cast_if_require_reports_when_missing_cast() {
        let int8 = PgBuiltInType::int8();
        let mut ctx = TestParseContext::default();
        let mut value = int_three();

        SemScalarExpr::implicit_cast_if_require(&mut ctx, &int8, &mut value)
            .expect("implicit cast check");

        assert!(matches!(value, SemScalarExpr::Const(_)));
        assert_eq!(ctx.reported_problem_count(), 1);
    }
}
381
382impl SemScalarExpr {
383 pub fn require_array<TParseContext>(
385 &self,
386 context: &mut TParseContext,
387 ) -> Result<(), AnalysisError>
388 where
389 TParseContext: ParseContext,
390 {
391 if let Some(ty) = self.get_type()
392 && let Some(is_array) = ty.is_array()
393 && !is_array
394 {
395 context.report_problem(AnalysisProblem::array_required(self, &ty))?;
396 }
397 Ok(())
398 }
399}
400
#[cfg(test)]
mod test_require_array {
    use super::{ImplicitCastExpr, ScalarConstExpr, SemScalarExpr, TypeReference};
    use crate::sem::PgBuiltInType;
    use crate::test_helpers::TestParseContext;

    /// Runs `require_array` on `expr` with a fresh context and returns that context.
    fn checked(expr: &SemScalarExpr) -> TestParseContext {
        let mut ctx = TestParseContext::default();
        expr.require_array(&mut ctx).expect("require array");
        ctx
    }

    #[test]
    fn require_array_reports_for_non_array() {
        let scalar = SemScalarExpr::Const(ScalarConstExpr::new_integer(5));

        assert_eq!(checked(&scalar).reported_problem_count(), 1);
    }

    #[test]
    fn require_array_accepts_array_type() {
        let int4_array = TypeReference::concrete_type_ref(PgBuiltInType::int4().full_name(), true);
        let five = SemScalarExpr::Const(ScalarConstExpr::new_integer(5));
        let as_array = SemScalarExpr::ImplicitCast(ImplicitCastExpr::new(&five, &int4_array));

        assert_eq!(checked(&as_array).reported_problem_count(), 0);
    }
}
427
428impl SemScalarExpr {
429 pub fn fits_array_index<TParseContext>(
431 &self,
432 context: &mut TParseContext,
433 ) -> Result<(), AnalysisError>
434 where
435 TParseContext: ParseContext,
436 {
437 if let Some(ty) = self.get_type()
438 && context
439 .get_implicit_cast(&ty, &PgBuiltInType::int2())
440 .is_none()
441 && context
442 .get_implicit_cast(&ty, &PgBuiltInType::int4())
443 .is_none()
444 && context
445 .get_implicit_cast(&ty, &PgBuiltInType::int8())
446 .is_none()
447 {
448 context.report_problem(AnalysisProblem::array_index_type_missmatch(self, &ty))?;
449 }
450 Ok(())
451 }
452}
453
454#[cfg(test)]
455mod test_fits_array_index {
456 use super::{ScalarConstExpr, SemScalarExpr};
457 use crate::sem::{CastContext, CastDefinition, PgBuiltInType};
458 use crate::test_helpers::TestParseContext;
459
460 #[test]
461 fn fits_array_index_reports_when_not_castable() {
462 let mut context = TestParseContext::default();
463 let expr = SemScalarExpr::Const(ScalarConstExpr::String("value".to_string()));
464
465 expr.fits_array_index(&mut context)
466 .expect("array index check");
467 assert_eq!(context.reported_problem_count(), 1);
468 }
469
470 #[test]
471 fn fits_array_index_accepts_castable_type() {
472 let mut context = TestParseContext::default();
473 let int4 = PgBuiltInType::int4();
474 context.set_get_implicit_cast_result(
475 &int4,
476 &int4,
477 Some(CastDefinition::new(CastContext::NoConversion)),
478 );
479 let expr = SemScalarExpr::Const(ScalarConstExpr::new_integer(7));
480
481 expr.fits_array_index(&mut context)
482 .expect("array index check");
483 assert_eq!(context.reported_problem_count(), 0);
484 }
485}