Skip to main content

surql_parser/upstream/sql/
expression.rs

1use crate::compat::types::{PublicFile, PublicNumber, PublicRecordId, PublicValue};
2use crate::upstream::fmt::{CoverStmts, EscapeIdent};
3use crate::upstream::sql::ast::ExplainFormat;
4use crate::upstream::sql::literal::ObjectEntry;
5use crate::upstream::sql::lookup::LookupKind;
6use crate::upstream::sql::operator::BindingPower;
7use crate::upstream::sql::statements::{
8	AlterStatement, CreateStatement, DefineStatement, DeleteStatement, ForeachStatement,
9	IfelseStatement, InfoStatement, InsertStatement, OutputStatement, RebuildStatement,
10	RelateStatement, RemoveStatement, SelectStatement, SetStatement, SleepStatement,
11	UpdateStatement, UpsertStatement,
12};
13use crate::upstream::sql::{
14	BinaryOperator, Block, Closure, Constant, Dir, FunctionCall, Idiom, Literal, Mock, Param, Part,
15	PostfixOperator, PrefixOperator, RecordIdKeyLit, RecordIdLit,
16};
17use std::ops::Bound;
18use surrealdb_types::{SqlFormat, ToSql, write_sql};
/// A node of the SurrealQL syntax tree.
///
/// Besides values, paths and operator applications, whole statements
/// (`SELECT`, `CREATE`, `DEFINE`, `LET`, ...) are also variants of this enum, so
/// they can appear wherever an expression is accepted. Textual output is produced
/// by the `ToSql` implementation for this type.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Expr {
	/// A literal value (numbers, strings, arrays, objects, record ids, ...).
	Literal(Literal),
	/// A parameter reference.
	Param(Param),
	/// An idiom: a path made of `Part`s (field accesses, graph lookups, ...).
	Idiom(Idiom),
	/// A bare table name; formatted as an escaped identifier.
	Table(String),
	/// A mock expression.
	Mock(Mock),
	/// A block of statements.
	Block(Box<Block>),
	/// A named constant.
	Constant(Constant),
	/// A prefix operator applied to an operand, formatted as `op expr`.
	Prefix {
		/// The operator written before the operand.
		op: PrefixOperator,
		/// The operand.
		expr: Box<Expr>,
	},
	/// A postfix operator applied to an operand, formatted as `expr op`.
	Postfix {
		/// The operand.
		expr: Box<Expr>,
		/// The operator written after the operand.
		op: PostfixOperator,
	},
	/// A binary operator applied to two operands, formatted as `left op right`.
	Binary {
		/// The left-hand operand.
		left: Box<Expr>,
		/// The operator between the operands.
		op: BinaryOperator,
		/// The right-hand operand.
		right: Box<Expr>,
	},
	/// A function call.
	FunctionCall(Box<FunctionCall>),
	/// A closure.
	Closure(Box<Closure>),
	/// The `BREAK` statement.
	Break,
	/// The `CONTINUE` statement.
	Continue,
	/// The `THROW` statement with the thrown expression.
	Throw(Box<Expr>),
	/// A return/output statement.
	Return(Box<OutputStatement>),
	/// An `IF ... ELSE` statement.
	IfElse(Box<IfelseStatement>),
	/// A `SELECT` statement.
	Select(Box<SelectStatement>),
	/// A `CREATE` statement.
	Create(Box<CreateStatement>),
	/// An `UPDATE` statement.
	Update(Box<UpdateStatement>),
	/// A `DELETE` statement.
	Delete(Box<DeleteStatement>),
	/// A `RELATE` statement.
	Relate(Box<RelateStatement>),
	/// An `INSERT` statement.
	Insert(Box<InsertStatement>),
	/// A `DEFINE` statement.
	Define(Box<DefineStatement>),
	/// A `REMOVE` statement.
	Remove(Box<RemoveStatement>),
	/// A `REBUILD` statement.
	Rebuild(Box<RebuildStatement>),
	/// An `UPSERT` statement.
	Upsert(Box<UpsertStatement>),
	/// An `ALTER` statement.
	Alter(Box<AlterStatement>),
	/// An `INFO` statement.
	Info(Box<InfoStatement>),
	/// A `FOR` loop statement.
	Foreach(Box<ForeachStatement>),
	/// A `LET` statement (parameter assignment).
	Let(Box<SetStatement>),
	/// A `SLEEP` statement.
	Sleep(Box<SleepStatement>),
	/// `EXPLAIN [ANALYZE] FORMAT <TEXT|JSON> <statement>`.
	Explain {
		/// Output format of the explanation (`TEXT` or `JSON`).
		format: ExplainFormat,
		/// Whether `ANALYZE` was requested.
		analyze: bool,
		/// The statement being explained.
		statement: Box<Expr>,
	},
}
70impl Expr {
71	pub fn to_idiom(&self) -> Idiom {
72		match self {
73			Expr::Idiom(i) => i.simplify(),
74			Expr::Param(i) => Idiom::field(i.clone().to_string()),
75			Expr::FunctionCall(x) => x.receiver.to_idiom(),
76			Expr::Literal(l) => match l {
77				Literal::String(s) => Idiom::field(s.clone()),
78				Literal::Datetime(d) => Idiom::field(d.to_string()),
79				x => Idiom::field(x.to_sql()),
80			},
81			x => Idiom::field(x.to_sql()),
82		}
83	}
84	pub fn from_public_value(value: PublicValue) -> Self {
85		match value {
86			PublicValue::None => Expr::Literal(Literal::None),
87			PublicValue::Null => Expr::Literal(Literal::Null),
88			PublicValue::Bool(x) => Expr::Literal(Literal::Bool(x)),
89			PublicValue::Number(PublicNumber::Float(x)) => Expr::Literal(Literal::Float(x)),
90			PublicValue::Number(PublicNumber::Int(x)) => Expr::Literal(Literal::Integer(x)),
91			PublicValue::Number(PublicNumber::Decimal(x)) => Expr::Literal(Literal::Decimal(x)),
92			PublicValue::String(x) => Expr::Literal(Literal::String(x)),
93			PublicValue::Bytes(x) => Expr::Literal(Literal::Bytes(x)),
94			PublicValue::Regex(x) => Expr::Literal(Literal::Regex(x)),
95			PublicValue::Table(x) => Expr::Table(x.to_string()),
96			PublicValue::RecordId(PublicRecordId { table, key }) => {
97				Expr::Literal(Literal::RecordId(RecordIdLit {
98					table: table.to_string(),
99					key: RecordIdKeyLit::from_record_id_key(key),
100				}))
101			}
102			PublicValue::Array(x) => Expr::Literal(Literal::Array(
103				x.into_iter().map(Expr::from_public_value).collect(),
104			)),
105			PublicValue::Set(x) => Expr::Literal(Literal::Array(
106				x.into_iter().map(Expr::from_public_value).collect(),
107			)),
108			PublicValue::Object(x) => Expr::Literal(Literal::Object(
109				x.into_iter()
110					.map(|(k, v)| ObjectEntry {
111						key: k,
112						value: Expr::from_public_value(v),
113					})
114					.collect(),
115			)),
116			PublicValue::Duration(x) => Expr::Literal(Literal::Duration(x)),
117			PublicValue::Datetime(x) => Expr::Literal(Literal::Datetime(x)),
118			PublicValue::Uuid(x) => Expr::Literal(Literal::Uuid(x)),
119			PublicValue::Geometry(x) => Expr::Literal(Literal::Geometry(x)),
120			PublicValue::File(x) => Expr::Literal(Literal::File(PublicFile::new(x.bucket, x.key))),
121			PublicValue::Range(x) => convert_public_range_to_literal(*x),
122		}
123	}
124	/// Returns if this expression needs to be parenthesized when inside another expression.
125	pub fn needs_parentheses(&self) -> bool {
126		match self {
127			Expr::Literal(Literal::UnboundedRange | Literal::RecordId(_))
128			| Expr::Closure(_)
129			| Expr::Break
130			| Expr::Continue
131			| Expr::Throw(_)
132			| Expr::Return(_)
133			| Expr::IfElse(_)
134			| Expr::Select(_)
135			| Expr::Create(_)
136			| Expr::Update(_)
137			| Expr::Delete(_)
138			| Expr::Relate(_)
139			| Expr::Insert(_)
140			| Expr::Define(_)
141			| Expr::Remove(_)
142			| Expr::Rebuild(_)
143			| Expr::Upsert(_)
144			| Expr::Alter(_)
145			| Expr::Info(_)
146			| Expr::Foreach(_)
147			| Expr::Let(_)
148			| Expr::Sleep(_)
149			| Expr::Explain { .. } => true,
150			Expr::Postfix { op, .. } => {
151				matches!(
152					op,
153					PostfixOperator::Range
154						| PostfixOperator::RangeSkip
155						| PostfixOperator::MethodCall(_, _)
156						| PostfixOperator::Call(_)
157				)
158			}
159			Expr::Literal(_)
160			| Expr::Param(_)
161			| Expr::Idiom(_)
162			| Expr::Table(_)
163			| Expr::Mock(_)
164			| Expr::Block(_)
165			| Expr::Constant(_)
166			| Expr::Prefix { .. }
167			| Expr::Binary { .. }
168			| Expr::FunctionCall(_) => false,
169		}
170	}
171	/// Returns true if there is a `NONE` or `NULL` value in the left most spot when formatting.
172	/// returns true for `NONE + 1`, `NULL()`, `NONE`, `NULL..` etc.
173	///
174	/// Required for proper formatting when `NONE` can conflict with a clause.
175	pub fn has_left_none_null(&self) -> bool {
176		match self {
177			Expr::Literal(Literal::None) | Expr::Literal(Literal::Null) => true,
178			Expr::Binary { left: expr, .. } | Expr::Postfix { expr, .. } => {
179				expr.has_left_none_null()
180			}
181			Expr::Idiom(x) => {
182				if let Some(Part::Start(x)) = x.0.first() {
183					x.has_left_none_null()
184				} else {
185					false
186				}
187			}
188			_ => false,
189		}
190	}
191	pub fn has_left_minus(&self) -> bool {
192		match self {
193			Expr::Prefix {
194				op: PrefixOperator::Negate,
195				..
196			} => true,
197			Expr::Postfix { expr, .. } | Expr::Binary { left: expr, .. } => expr.has_left_minus(),
198			Expr::Literal(Literal::Integer(x)) => x.is_negative(),
199			Expr::Literal(Literal::Float(x)) => x.is_sign_negative(),
200			Expr::Literal(Literal::Decimal(x)) => x.is_sign_negative(),
201			Expr::Idiom(x) => {
202				if let Some(x) = x.0.first()
203					&& let Part::Graph(lookup) = x
204					&& let LookupKind::Graph(Dir::Out) = lookup.kind
205				{
206					return true;
207				}
208				false
209			}
210			_ => false,
211		}
212	}
213	pub fn has_left_idiom(&self) -> bool {
214		match self {
215			Expr::Idiom(_) => true,
216			Expr::Postfix { expr, .. } | Expr::Binary { left: expr, .. } => expr.has_left_idiom(),
217			_ => false,
218		}
219	}
220}
221fn convert_public_geometry_to_internal(
222	geom: surrealdb_types::Geometry,
223) -> crate::compat::val::Geometry {
224	match geom {
225		surrealdb_types::Geometry::Point(p) => crate::compat::val::Geometry::Point(p),
226		surrealdb_types::Geometry::Line(l) => crate::compat::val::Geometry::Line(l),
227		surrealdb_types::Geometry::Polygon(p) => crate::compat::val::Geometry::Polygon(p),
228		surrealdb_types::Geometry::MultiPoint(mp) => crate::compat::val::Geometry::MultiPoint(mp),
229		surrealdb_types::Geometry::MultiLine(ml) => crate::compat::val::Geometry::MultiLine(ml),
230		surrealdb_types::Geometry::MultiPolygon(mp) => {
231			crate::compat::val::Geometry::MultiPolygon(mp)
232		}
233		surrealdb_types::Geometry::Collection(c) => crate::compat::val::Geometry::Collection(
234			c.into_iter()
235				.map(convert_public_geometry_to_internal)
236				.collect(),
237		),
238	}
239}
240fn convert_public_range_to_literal(range: surrealdb_types::Range) -> Expr {
241	use crate::upstream::sql::literal::Literal;
242	use crate::upstream::sql::operator::BinaryOperator;
243	let range = range.into_inner();
244	let op = match (&range.0, &range.1) {
245		(std::ops::Bound::Included(_), std::ops::Bound::Included(_)) => {
246			BinaryOperator::RangeInclusive
247		}
248		_ => BinaryOperator::Range,
249	};
250	let start_expr = match range.0 {
251		std::ops::Bound::Included(v) => Expr::from_public_value(v),
252		std::ops::Bound::Excluded(v) => Expr::from_public_value(v),
253		std::ops::Bound::Unbounded => Expr::Literal(Literal::None),
254	};
255	let end_expr = match range.1 {
256		std::ops::Bound::Included(v) => Expr::from_public_value(v),
257		std::ops::Bound::Excluded(v) => Expr::from_public_value(v),
258		std::ops::Bound::Unbounded => Expr::Literal(Literal::None),
259	};
260	Expr::Binary {
261		left: Box::new(start_expr),
262		op,
263		right: Box::new(end_expr),
264	}
265}
266pub fn convert_public_value_to_internal(
267	value: surrealdb_types::Value,
268) -> crate::compat::val::Value {
269	match value {
270		surrealdb_types::Value::None => crate::compat::val::Value::None,
271		surrealdb_types::Value::Null => crate::compat::val::Value::Null,
272		surrealdb_types::Value::Bool(b) => crate::compat::val::Value::Bool(b),
273		surrealdb_types::Value::Number(n) => match n {
274			surrealdb_types::Number::Int(i) => {
275				crate::compat::val::Value::Number(crate::compat::val::Number::Int(i))
276			}
277			surrealdb_types::Number::Float(f) => {
278				crate::compat::val::Value::Number(crate::compat::val::Number::Float(f))
279			}
280			surrealdb_types::Number::Decimal(d) => {
281				crate::compat::val::Value::Number(crate::compat::val::Number::Decimal(d))
282			}
283		},
284		surrealdb_types::Value::String(s) => crate::compat::val::Value::String(s),
285		surrealdb_types::Value::Duration(d) => crate::compat::val::Value::Duration(d),
286		surrealdb_types::Value::Datetime(dt) => crate::compat::val::Value::Datetime(dt),
287		surrealdb_types::Value::Uuid(u) => crate::compat::val::Value::Uuid(u),
288		surrealdb_types::Value::Array(a) => {
289			crate::compat::val::Value::Array(crate::compat::val::Array::from(
290				a.into_iter()
291					.map(convert_public_value_to_internal)
292					.collect::<Vec<_>>(),
293			))
294		}
295		surrealdb_types::Value::Set(s) => {
296			crate::compat::val::Value::Set(crate::compat::val::Set::from(
297				s.into_iter()
298					.map(convert_public_value_to_internal)
299					.collect::<std::collections::BTreeSet<_>>(),
300			))
301		}
302		surrealdb_types::Value::Object(o) => {
303			crate::compat::val::Value::Object(crate::compat::val::Object::from(
304				o.into_iter()
305					.map(|(k, v)| (k, convert_public_value_to_internal(v)))
306					.collect::<std::collections::BTreeMap<_, _>>(),
307			))
308		}
309		surrealdb_types::Value::Geometry(g) => {
310			crate::compat::val::Value::Geometry(convert_public_geometry_to_internal(g))
311		}
312		surrealdb_types::Value::Bytes(b) => crate::compat::val::Value::Bytes(b),
313		surrealdb_types::Value::Table(t) => crate::compat::val::Value::Table(t.into()),
314		surrealdb_types::Value::RecordId(PublicRecordId { table, key }) => {
315			let key = convert_public_record_id_key_to_internal(key);
316			crate::compat::val::Value::RecordId(crate::compat::val::RecordId {
317				table: table.into(),
318				key,
319			})
320		}
321		surrealdb_types::Value::File(f) => {
322			crate::compat::val::Value::File(crate::compat::val::File {
323				bucket: f.bucket,
324				key: f.key,
325			})
326		}
327		surrealdb_types::Value::Range(r) => {
328			crate::compat::val::Value::Range(Box::new(crate::compat::val::Range {
329				start: match r.start {
330					Bound::Included(v) => Bound::Included(convert_public_value_to_internal(v)),
331					Bound::Excluded(v) => Bound::Excluded(convert_public_value_to_internal(v)),
332					Bound::Unbounded => Bound::Unbounded,
333				},
334				end: match r.end {
335					Bound::Included(v) => Bound::Included(convert_public_value_to_internal(v)),
336					Bound::Excluded(v) => Bound::Excluded(convert_public_value_to_internal(v)),
337					Bound::Unbounded => Bound::Unbounded,
338				},
339			}))
340		}
341		surrealdb_types::Value::Regex(r) => crate::compat::val::Value::Regex(r),
342	}
343}
344fn convert_public_record_id_key_to_internal(
345	key: surrealdb_types::RecordIdKey,
346) -> crate::compat::val::RecordIdKey {
347	match key {
348		surrealdb_types::RecordIdKey::Number(n) => crate::compat::val::RecordIdKey::Number(n),
349		surrealdb_types::RecordIdKey::String(s) => crate::compat::val::RecordIdKey::String(s),
350		surrealdb_types::RecordIdKey::Uuid(u) => crate::compat::val::RecordIdKey::Uuid(u),
351		surrealdb_types::RecordIdKey::Array(a) => {
352			crate::compat::val::RecordIdKey::Array(crate::compat::val::Array::from(
353				a.into_iter()
354					.map(convert_public_value_to_internal)
355					.collect::<Vec<_>>(),
356			))
357		}
358		surrealdb_types::RecordIdKey::Object(o) => {
359			crate::compat::val::RecordIdKey::Object(crate::compat::val::Object::from(
360				o.into_iter()
361					.map(|(k, v)| (k, convert_public_value_to_internal(v)))
362					.collect::<std::collections::BTreeMap<_, _>>(),
363			))
364		}
365		surrealdb_types::RecordIdKey::Range(r) => {
366			crate::compat::val::RecordIdKey::Range(Box::new(crate::compat::val::RecordIdKeyRange {
367				start: match r.start {
368					Bound::Included(k) => {
369						Bound::Included(convert_public_record_id_key_to_internal(k))
370					}
371					Bound::Excluded(k) => {
372						Bound::Excluded(convert_public_record_id_key_to_internal(k))
373					}
374					Bound::Unbounded => Bound::Unbounded,
375				},
376				end: match r.end {
377					Bound::Included(k) => {
378						Bound::Included(convert_public_record_id_key_to_internal(k))
379					}
380					Bound::Excluded(k) => {
381						Bound::Excluded(convert_public_record_id_key_to_internal(k))
382					}
383					Bound::Unbounded => Bound::Unbounded,
384				},
385			}))
386		}
387	}
388}
389impl ToSql for Expr {
390	fn fmt_sql(&self, f: &mut String, fmt: SqlFormat) {
391		match self {
392			Expr::Literal(literal) => literal.fmt_sql(f, fmt),
393			Expr::Param(param) => param.fmt_sql(f, fmt),
394			Expr::Idiom(idiom) => idiom.fmt_sql(f, fmt),
395			Expr::Table(ident) => write_sql!(f, fmt, "{}", EscapeIdent(ident)),
396			Expr::Mock(mock) => mock.fmt_sql(f, fmt),
397			Expr::Block(block) => block.fmt_sql(f, fmt),
398			Expr::Constant(constant) => constant.fmt_sql(f, fmt),
399			Expr::Prefix { op, expr } => {
400				let expr_bp = BindingPower::for_expr(expr);
401				let op_bp = BindingPower::for_prefix_operator(op);
402				if expr.needs_parentheses()
403					|| expr_bp < op_bp
404					|| expr_bp == op_bp && matches!(expr_bp, BindingPower::Range)
405					|| *op == PrefixOperator::Negate && expr.has_left_minus()
406				{
407					write_sql!(f, fmt, "{op}({expr})");
408				} else {
409					write_sql!(f, fmt, "{op}{expr}");
410				}
411			}
412			Expr::Postfix { expr, op } => {
413				let expr_bp = BindingPower::for_expr(expr);
414				let op_bp = BindingPower::for_postfix_operator(op);
415				if expr.needs_parentheses()
416					|| expr_bp < op_bp
417					|| expr_bp == op_bp && matches!(expr_bp, BindingPower::Range)
418					|| matches!(op, PostfixOperator::Call(_))
419				{
420					write_sql!(f, fmt, "({expr}){op}");
421				} else {
422					write_sql!(f, fmt, "{expr}{op}");
423				}
424			}
425			Expr::Binary { left, op, right } => {
426				let op_bp = BindingPower::for_binary_operator(op);
427				let left_bp = BindingPower::for_expr(left);
428				let right_bp = BindingPower::for_expr(right);
429				if left.needs_parentheses()
430					|| left_bp < op_bp
431					|| left_bp == op_bp
432						&& matches!(
433							left_bp,
434							BindingPower::Range | BindingPower::Relation | BindingPower::Equality
435						) {
436					write_sql!(f, fmt, "({left})");
437				} else {
438					write_sql!(f, fmt, "{left}");
439				}
440				if matches!(
441					op,
442					BinaryOperator::Range
443						| BinaryOperator::RangeSkip
444						| BinaryOperator::RangeInclusive
445						| BinaryOperator::RangeSkipInclusive
446				) {
447					op.fmt_sql(f, fmt);
448				} else {
449					f.push(' ');
450					op.fmt_sql(f, fmt);
451					f.push(' ');
452				}
453				if right.needs_parentheses()
454					|| right_bp < op_bp
455					|| right_bp == op_bp
456						&& matches!(
457							right_bp,
458							BindingPower::Range | BindingPower::Relation | BindingPower::Equality
459						) {
460					write_sql!(f, fmt, "({right})");
461				} else {
462					write_sql!(f, fmt, "{right}");
463				}
464			}
465			Expr::FunctionCall(function_call) => function_call.fmt_sql(f, fmt),
466			Expr::Closure(closure) => closure.fmt_sql(f, fmt),
467			Expr::Break => f.push_str("BREAK"),
468			Expr::Continue => f.push_str("CONTINUE"),
469			Expr::Return(x) => x.fmt_sql(f, fmt),
470			Expr::Throw(expr) => {
471				write_sql!(f, fmt, "THROW {}", CoverStmts(expr.as_ref()))
472			}
473			Expr::IfElse(s) => s.fmt_sql(f, fmt),
474			Expr::Select(s) => s.fmt_sql(f, fmt),
475			Expr::Create(s) => s.fmt_sql(f, fmt),
476			Expr::Update(s) => s.fmt_sql(f, fmt),
477			Expr::Delete(s) => s.fmt_sql(f, fmt),
478			Expr::Relate(s) => s.fmt_sql(f, fmt),
479			Expr::Insert(s) => s.fmt_sql(f, fmt),
480			Expr::Define(s) => s.fmt_sql(f, fmt),
481			Expr::Remove(s) => s.fmt_sql(f, fmt),
482			Expr::Rebuild(s) => s.fmt_sql(f, fmt),
483			Expr::Upsert(s) => s.fmt_sql(f, fmt),
484			Expr::Alter(s) => s.fmt_sql(f, fmt),
485			Expr::Info(s) => s.fmt_sql(f, fmt),
486			Expr::Foreach(s) => s.fmt_sql(f, fmt),
487			Expr::Let(s) => s.fmt_sql(f, fmt),
488			Expr::Sleep(s) => s.fmt_sql(f, fmt),
489			Expr::Explain {
490				format: explain_format,
491				analyze,
492				statement,
493			} => {
494				f.push_str("EXPLAIN");
495				if *analyze {
496					f.push_str(" ANALYZE");
497				}
498				match explain_format {
499					ExplainFormat::Text => f.push_str(" FORMAT TEXT"),
500					ExplainFormat::Json => f.push_str(" FORMAT JSON"),
501				}
502				f.push(' ');
503				statement.fmt_sql(f, fmt);
504			}
505		}
506	}
507}