1use crate::compat::types::{PublicFile, PublicNumber, PublicRecordId, PublicValue};
2use crate::upstream::fmt::{CoverStmts, EscapeIdent};
3use crate::upstream::sql::ast::ExplainFormat;
4use crate::upstream::sql::literal::ObjectEntry;
5use crate::upstream::sql::lookup::LookupKind;
6use crate::upstream::sql::operator::BindingPower;
7use crate::upstream::sql::statements::{
8 AlterStatement, CreateStatement, DefineStatement, DeleteStatement, ForeachStatement,
9 IfelseStatement, InfoStatement, InsertStatement, OutputStatement, RebuildStatement,
10 RelateStatement, RemoveStatement, SelectStatement, SetStatement, SleepStatement,
11 UpdateStatement, UpsertStatement,
12};
13use crate::upstream::sql::{
14 BinaryOperator, Block, Closure, Constant, Dir, FunctionCall, Idiom, Literal, Mock, Param, Part,
15 PostfixOperator, PrefixOperator, RecordIdKeyLit, RecordIdLit,
16};
17use std::ops::Bound;
18use surrealdb_types::{SqlFormat, ToSql, write_sql};
/// An expression node: values, operator applications, control flow, and
/// statements that can appear in expression position.
///
/// The [`ToSql`] implementation for this type renders a node back to query
/// text, inserting parentheses around operator operands where required.
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Expr {
    /// A literal value.
    Literal(Literal),
    /// A parameter.
    Param(Param),
    /// An idiom, i.e. a sequence of [`Part`]s.
    Idiom(Idiom),
    /// A table name; rendered as an escaped identifier.
    Table(String),
    /// A mock.
    Mock(Mock),
    /// A block.
    Block(Box<Block>),
    /// A constant.
    Constant(Constant),
    /// A prefix operator applied to `expr`; rendered as operator then operand.
    Prefix {
        op: PrefixOperator,
        expr: Box<Expr>,
    },
    /// A postfix operator applied to `expr`; rendered as operand then operator.
    Postfix {
        expr: Box<Expr>,
        op: PostfixOperator,
    },
    /// A binary operator applied to `left` and `right`.
    Binary {
        left: Box<Expr>,
        op: BinaryOperator,
        right: Box<Expr>,
    },
    /// A function call.
    FunctionCall(Box<FunctionCall>),
    /// A closure.
    Closure(Box<Closure>),
    /// Rendered as `BREAK`.
    Break,
    /// Rendered as `CONTINUE`.
    Continue,
    /// Rendered as `THROW` followed by the carried expression.
    Throw(Box<Expr>),
    /// A return, carrying an [`OutputStatement`].
    Return(Box<OutputStatement>),
    // Statement forms. Each one renders by delegating to the boxed statement.
    IfElse(Box<IfelseStatement>),
    Select(Box<SelectStatement>),
    Create(Box<CreateStatement>),
    Update(Box<UpdateStatement>),
    Delete(Box<DeleteStatement>),
    Relate(Box<RelateStatement>),
    Insert(Box<InsertStatement>),
    Define(Box<DefineStatement>),
    Remove(Box<RemoveStatement>),
    Rebuild(Box<RebuildStatement>),
    Upsert(Box<UpsertStatement>),
    Alter(Box<AlterStatement>),
    Info(Box<InfoStatement>),
    Foreach(Box<ForeachStatement>),
    Let(Box<SetStatement>),
    Sleep(Box<SleepStatement>),
    /// Rendered as `EXPLAIN [ANALYZE] FORMAT TEXT|JSON <statement>`.
    Explain {
        /// Selects the `FORMAT TEXT` or `FORMAT JSON` clause.
        format: ExplainFormat,
        /// When `true`, the `ANALYZE` keyword is emitted.
        analyze: bool,
        /// The expression rendered after the `EXPLAIN` header.
        statement: Box<Expr>,
    },
}
70impl Expr {
71 pub fn to_idiom(&self) -> Idiom {
72 match self {
73 Expr::Idiom(i) => i.simplify(),
74 Expr::Param(i) => Idiom::field(i.clone().to_string()),
75 Expr::FunctionCall(x) => x.receiver.to_idiom(),
76 Expr::Literal(l) => match l {
77 Literal::String(s) => Idiom::field(s.clone()),
78 Literal::Datetime(d) => Idiom::field(d.to_string()),
79 x => Idiom::field(x.to_sql()),
80 },
81 x => Idiom::field(x.to_sql()),
82 }
83 }
84 pub fn from_public_value(value: PublicValue) -> Self {
85 match value {
86 PublicValue::None => Expr::Literal(Literal::None),
87 PublicValue::Null => Expr::Literal(Literal::Null),
88 PublicValue::Bool(x) => Expr::Literal(Literal::Bool(x)),
89 PublicValue::Number(PublicNumber::Float(x)) => Expr::Literal(Literal::Float(x)),
90 PublicValue::Number(PublicNumber::Int(x)) => Expr::Literal(Literal::Integer(x)),
91 PublicValue::Number(PublicNumber::Decimal(x)) => Expr::Literal(Literal::Decimal(x)),
92 PublicValue::String(x) => Expr::Literal(Literal::String(x)),
93 PublicValue::Bytes(x) => Expr::Literal(Literal::Bytes(x)),
94 PublicValue::Regex(x) => Expr::Literal(Literal::Regex(x)),
95 PublicValue::Table(x) => Expr::Table(x.to_string()),
96 PublicValue::RecordId(PublicRecordId { table, key }) => {
97 Expr::Literal(Literal::RecordId(RecordIdLit {
98 table: table.to_string(),
99 key: RecordIdKeyLit::from_record_id_key(key),
100 }))
101 }
102 PublicValue::Array(x) => Expr::Literal(Literal::Array(
103 x.into_iter().map(Expr::from_public_value).collect(),
104 )),
105 PublicValue::Set(x) => Expr::Literal(Literal::Array(
106 x.into_iter().map(Expr::from_public_value).collect(),
107 )),
108 PublicValue::Object(x) => Expr::Literal(Literal::Object(
109 x.into_iter()
110 .map(|(k, v)| ObjectEntry {
111 key: k,
112 value: Expr::from_public_value(v),
113 })
114 .collect(),
115 )),
116 PublicValue::Duration(x) => Expr::Literal(Literal::Duration(x)),
117 PublicValue::Datetime(x) => Expr::Literal(Literal::Datetime(x)),
118 PublicValue::Uuid(x) => Expr::Literal(Literal::Uuid(x)),
119 PublicValue::Geometry(x) => Expr::Literal(Literal::Geometry(x)),
120 PublicValue::File(x) => Expr::Literal(Literal::File(PublicFile::new(x.bucket, x.key))),
121 PublicValue::Range(x) => convert_public_range_to_literal(*x),
122 }
123 }
124 pub fn needs_parentheses(&self) -> bool {
126 match self {
127 Expr::Literal(Literal::UnboundedRange | Literal::RecordId(_))
128 | Expr::Closure(_)
129 | Expr::Break
130 | Expr::Continue
131 | Expr::Throw(_)
132 | Expr::Return(_)
133 | Expr::IfElse(_)
134 | Expr::Select(_)
135 | Expr::Create(_)
136 | Expr::Update(_)
137 | Expr::Delete(_)
138 | Expr::Relate(_)
139 | Expr::Insert(_)
140 | Expr::Define(_)
141 | Expr::Remove(_)
142 | Expr::Rebuild(_)
143 | Expr::Upsert(_)
144 | Expr::Alter(_)
145 | Expr::Info(_)
146 | Expr::Foreach(_)
147 | Expr::Let(_)
148 | Expr::Sleep(_)
149 | Expr::Explain { .. } => true,
150 Expr::Postfix { op, .. } => {
151 matches!(
152 op,
153 PostfixOperator::Range
154 | PostfixOperator::RangeSkip
155 | PostfixOperator::MethodCall(_, _)
156 | PostfixOperator::Call(_)
157 )
158 }
159 Expr::Literal(_)
160 | Expr::Param(_)
161 | Expr::Idiom(_)
162 | Expr::Table(_)
163 | Expr::Mock(_)
164 | Expr::Block(_)
165 | Expr::Constant(_)
166 | Expr::Prefix { .. }
167 | Expr::Binary { .. }
168 | Expr::FunctionCall(_) => false,
169 }
170 }
171 pub fn has_left_none_null(&self) -> bool {
176 match self {
177 Expr::Literal(Literal::None) | Expr::Literal(Literal::Null) => true,
178 Expr::Binary { left: expr, .. } | Expr::Postfix { expr, .. } => {
179 expr.has_left_none_null()
180 }
181 Expr::Idiom(x) => {
182 if let Some(Part::Start(x)) = x.0.first() {
183 x.has_left_none_null()
184 } else {
185 false
186 }
187 }
188 _ => false,
189 }
190 }
191 pub fn has_left_minus(&self) -> bool {
192 match self {
193 Expr::Prefix {
194 op: PrefixOperator::Negate,
195 ..
196 } => true,
197 Expr::Postfix { expr, .. } | Expr::Binary { left: expr, .. } => expr.has_left_minus(),
198 Expr::Literal(Literal::Integer(x)) => x.is_negative(),
199 Expr::Literal(Literal::Float(x)) => x.is_sign_negative(),
200 Expr::Literal(Literal::Decimal(x)) => x.is_sign_negative(),
201 Expr::Idiom(x) => {
202 if let Some(x) = x.0.first()
203 && let Part::Graph(lookup) = x
204 && let LookupKind::Graph(Dir::Out) = lookup.kind
205 {
206 return true;
207 }
208 false
209 }
210 _ => false,
211 }
212 }
213 pub fn has_left_idiom(&self) -> bool {
214 match self {
215 Expr::Idiom(_) => true,
216 Expr::Postfix { expr, .. } | Expr::Binary { left: expr, .. } => expr.has_left_idiom(),
217 _ => false,
218 }
219 }
220}
221fn convert_public_geometry_to_internal(
222 geom: surrealdb_types::Geometry,
223) -> crate::compat::val::Geometry {
224 match geom {
225 surrealdb_types::Geometry::Point(p) => crate::compat::val::Geometry::Point(p),
226 surrealdb_types::Geometry::Line(l) => crate::compat::val::Geometry::Line(l),
227 surrealdb_types::Geometry::Polygon(p) => crate::compat::val::Geometry::Polygon(p),
228 surrealdb_types::Geometry::MultiPoint(mp) => crate::compat::val::Geometry::MultiPoint(mp),
229 surrealdb_types::Geometry::MultiLine(ml) => crate::compat::val::Geometry::MultiLine(ml),
230 surrealdb_types::Geometry::MultiPolygon(mp) => {
231 crate::compat::val::Geometry::MultiPolygon(mp)
232 }
233 surrealdb_types::Geometry::Collection(c) => crate::compat::val::Geometry::Collection(
234 c.into_iter()
235 .map(convert_public_geometry_to_internal)
236 .collect(),
237 ),
238 }
239}
240fn convert_public_range_to_literal(range: surrealdb_types::Range) -> Expr {
241 use crate::upstream::sql::literal::Literal;
242 use crate::upstream::sql::operator::BinaryOperator;
243 let range = range.into_inner();
244 let op = match (&range.0, &range.1) {
245 (std::ops::Bound::Included(_), std::ops::Bound::Included(_)) => {
246 BinaryOperator::RangeInclusive
247 }
248 _ => BinaryOperator::Range,
249 };
250 let start_expr = match range.0 {
251 std::ops::Bound::Included(v) => Expr::from_public_value(v),
252 std::ops::Bound::Excluded(v) => Expr::from_public_value(v),
253 std::ops::Bound::Unbounded => Expr::Literal(Literal::None),
254 };
255 let end_expr = match range.1 {
256 std::ops::Bound::Included(v) => Expr::from_public_value(v),
257 std::ops::Bound::Excluded(v) => Expr::from_public_value(v),
258 std::ops::Bound::Unbounded => Expr::Literal(Literal::None),
259 };
260 Expr::Binary {
261 left: Box::new(start_expr),
262 op,
263 right: Box::new(end_expr),
264 }
265}
266pub fn convert_public_value_to_internal(
267 value: surrealdb_types::Value,
268) -> crate::compat::val::Value {
269 match value {
270 surrealdb_types::Value::None => crate::compat::val::Value::None,
271 surrealdb_types::Value::Null => crate::compat::val::Value::Null,
272 surrealdb_types::Value::Bool(b) => crate::compat::val::Value::Bool(b),
273 surrealdb_types::Value::Number(n) => match n {
274 surrealdb_types::Number::Int(i) => {
275 crate::compat::val::Value::Number(crate::compat::val::Number::Int(i))
276 }
277 surrealdb_types::Number::Float(f) => {
278 crate::compat::val::Value::Number(crate::compat::val::Number::Float(f))
279 }
280 surrealdb_types::Number::Decimal(d) => {
281 crate::compat::val::Value::Number(crate::compat::val::Number::Decimal(d))
282 }
283 },
284 surrealdb_types::Value::String(s) => crate::compat::val::Value::String(s),
285 surrealdb_types::Value::Duration(d) => crate::compat::val::Value::Duration(d),
286 surrealdb_types::Value::Datetime(dt) => crate::compat::val::Value::Datetime(dt),
287 surrealdb_types::Value::Uuid(u) => crate::compat::val::Value::Uuid(u),
288 surrealdb_types::Value::Array(a) => {
289 crate::compat::val::Value::Array(crate::compat::val::Array::from(
290 a.into_iter()
291 .map(convert_public_value_to_internal)
292 .collect::<Vec<_>>(),
293 ))
294 }
295 surrealdb_types::Value::Set(s) => {
296 crate::compat::val::Value::Set(crate::compat::val::Set::from(
297 s.into_iter()
298 .map(convert_public_value_to_internal)
299 .collect::<std::collections::BTreeSet<_>>(),
300 ))
301 }
302 surrealdb_types::Value::Object(o) => {
303 crate::compat::val::Value::Object(crate::compat::val::Object::from(
304 o.into_iter()
305 .map(|(k, v)| (k, convert_public_value_to_internal(v)))
306 .collect::<std::collections::BTreeMap<_, _>>(),
307 ))
308 }
309 surrealdb_types::Value::Geometry(g) => {
310 crate::compat::val::Value::Geometry(convert_public_geometry_to_internal(g))
311 }
312 surrealdb_types::Value::Bytes(b) => crate::compat::val::Value::Bytes(b),
313 surrealdb_types::Value::Table(t) => crate::compat::val::Value::Table(t.into()),
314 surrealdb_types::Value::RecordId(PublicRecordId { table, key }) => {
315 let key = convert_public_record_id_key_to_internal(key);
316 crate::compat::val::Value::RecordId(crate::compat::val::RecordId {
317 table: table.into(),
318 key,
319 })
320 }
321 surrealdb_types::Value::File(f) => {
322 crate::compat::val::Value::File(crate::compat::val::File {
323 bucket: f.bucket,
324 key: f.key,
325 })
326 }
327 surrealdb_types::Value::Range(r) => {
328 crate::compat::val::Value::Range(Box::new(crate::compat::val::Range {
329 start: match r.start {
330 Bound::Included(v) => Bound::Included(convert_public_value_to_internal(v)),
331 Bound::Excluded(v) => Bound::Excluded(convert_public_value_to_internal(v)),
332 Bound::Unbounded => Bound::Unbounded,
333 },
334 end: match r.end {
335 Bound::Included(v) => Bound::Included(convert_public_value_to_internal(v)),
336 Bound::Excluded(v) => Bound::Excluded(convert_public_value_to_internal(v)),
337 Bound::Unbounded => Bound::Unbounded,
338 },
339 }))
340 }
341 surrealdb_types::Value::Regex(r) => crate::compat::val::Value::Regex(r),
342 }
343}
344fn convert_public_record_id_key_to_internal(
345 key: surrealdb_types::RecordIdKey,
346) -> crate::compat::val::RecordIdKey {
347 match key {
348 surrealdb_types::RecordIdKey::Number(n) => crate::compat::val::RecordIdKey::Number(n),
349 surrealdb_types::RecordIdKey::String(s) => crate::compat::val::RecordIdKey::String(s),
350 surrealdb_types::RecordIdKey::Uuid(u) => crate::compat::val::RecordIdKey::Uuid(u),
351 surrealdb_types::RecordIdKey::Array(a) => {
352 crate::compat::val::RecordIdKey::Array(crate::compat::val::Array::from(
353 a.into_iter()
354 .map(convert_public_value_to_internal)
355 .collect::<Vec<_>>(),
356 ))
357 }
358 surrealdb_types::RecordIdKey::Object(o) => {
359 crate::compat::val::RecordIdKey::Object(crate::compat::val::Object::from(
360 o.into_iter()
361 .map(|(k, v)| (k, convert_public_value_to_internal(v)))
362 .collect::<std::collections::BTreeMap<_, _>>(),
363 ))
364 }
365 surrealdb_types::RecordIdKey::Range(r) => {
366 crate::compat::val::RecordIdKey::Range(Box::new(crate::compat::val::RecordIdKeyRange {
367 start: match r.start {
368 Bound::Included(k) => {
369 Bound::Included(convert_public_record_id_key_to_internal(k))
370 }
371 Bound::Excluded(k) => {
372 Bound::Excluded(convert_public_record_id_key_to_internal(k))
373 }
374 Bound::Unbounded => Bound::Unbounded,
375 },
376 end: match r.end {
377 Bound::Included(k) => {
378 Bound::Included(convert_public_record_id_key_to_internal(k))
379 }
380 Bound::Excluded(k) => {
381 Bound::Excluded(convert_public_record_id_key_to_internal(k))
382 }
383 Bound::Unbounded => Bound::Unbounded,
384 },
385 }))
386 }
387 }
388}
389impl ToSql for Expr {
390 fn fmt_sql(&self, f: &mut String, fmt: SqlFormat) {
391 match self {
392 Expr::Literal(literal) => literal.fmt_sql(f, fmt),
393 Expr::Param(param) => param.fmt_sql(f, fmt),
394 Expr::Idiom(idiom) => idiom.fmt_sql(f, fmt),
395 Expr::Table(ident) => write_sql!(f, fmt, "{}", EscapeIdent(ident)),
396 Expr::Mock(mock) => mock.fmt_sql(f, fmt),
397 Expr::Block(block) => block.fmt_sql(f, fmt),
398 Expr::Constant(constant) => constant.fmt_sql(f, fmt),
399 Expr::Prefix { op, expr } => {
400 let expr_bp = BindingPower::for_expr(expr);
401 let op_bp = BindingPower::for_prefix_operator(op);
402 if expr.needs_parentheses()
403 || expr_bp < op_bp
404 || expr_bp == op_bp && matches!(expr_bp, BindingPower::Range)
405 || *op == PrefixOperator::Negate && expr.has_left_minus()
406 {
407 write_sql!(f, fmt, "{op}({expr})");
408 } else {
409 write_sql!(f, fmt, "{op}{expr}");
410 }
411 }
412 Expr::Postfix { expr, op } => {
413 let expr_bp = BindingPower::for_expr(expr);
414 let op_bp = BindingPower::for_postfix_operator(op);
415 if expr.needs_parentheses()
416 || expr_bp < op_bp
417 || expr_bp == op_bp && matches!(expr_bp, BindingPower::Range)
418 || matches!(op, PostfixOperator::Call(_))
419 {
420 write_sql!(f, fmt, "({expr}){op}");
421 } else {
422 write_sql!(f, fmt, "{expr}{op}");
423 }
424 }
425 Expr::Binary { left, op, right } => {
426 let op_bp = BindingPower::for_binary_operator(op);
427 let left_bp = BindingPower::for_expr(left);
428 let right_bp = BindingPower::for_expr(right);
429 if left.needs_parentheses()
430 || left_bp < op_bp
431 || left_bp == op_bp
432 && matches!(
433 left_bp,
434 BindingPower::Range | BindingPower::Relation | BindingPower::Equality
435 ) {
436 write_sql!(f, fmt, "({left})");
437 } else {
438 write_sql!(f, fmt, "{left}");
439 }
440 if matches!(
441 op,
442 BinaryOperator::Range
443 | BinaryOperator::RangeSkip
444 | BinaryOperator::RangeInclusive
445 | BinaryOperator::RangeSkipInclusive
446 ) {
447 op.fmt_sql(f, fmt);
448 } else {
449 f.push(' ');
450 op.fmt_sql(f, fmt);
451 f.push(' ');
452 }
453 if right.needs_parentheses()
454 || right_bp < op_bp
455 || right_bp == op_bp
456 && matches!(
457 right_bp,
458 BindingPower::Range | BindingPower::Relation | BindingPower::Equality
459 ) {
460 write_sql!(f, fmt, "({right})");
461 } else {
462 write_sql!(f, fmt, "{right}");
463 }
464 }
465 Expr::FunctionCall(function_call) => function_call.fmt_sql(f, fmt),
466 Expr::Closure(closure) => closure.fmt_sql(f, fmt),
467 Expr::Break => f.push_str("BREAK"),
468 Expr::Continue => f.push_str("CONTINUE"),
469 Expr::Return(x) => x.fmt_sql(f, fmt),
470 Expr::Throw(expr) => {
471 write_sql!(f, fmt, "THROW {}", CoverStmts(expr.as_ref()))
472 }
473 Expr::IfElse(s) => s.fmt_sql(f, fmt),
474 Expr::Select(s) => s.fmt_sql(f, fmt),
475 Expr::Create(s) => s.fmt_sql(f, fmt),
476 Expr::Update(s) => s.fmt_sql(f, fmt),
477 Expr::Delete(s) => s.fmt_sql(f, fmt),
478 Expr::Relate(s) => s.fmt_sql(f, fmt),
479 Expr::Insert(s) => s.fmt_sql(f, fmt),
480 Expr::Define(s) => s.fmt_sql(f, fmt),
481 Expr::Remove(s) => s.fmt_sql(f, fmt),
482 Expr::Rebuild(s) => s.fmt_sql(f, fmt),
483 Expr::Upsert(s) => s.fmt_sql(f, fmt),
484 Expr::Alter(s) => s.fmt_sql(f, fmt),
485 Expr::Info(s) => s.fmt_sql(f, fmt),
486 Expr::Foreach(s) => s.fmt_sql(f, fmt),
487 Expr::Let(s) => s.fmt_sql(f, fmt),
488 Expr::Sleep(s) => s.fmt_sql(f, fmt),
489 Expr::Explain {
490 format: explain_format,
491 analyze,
492 statement,
493 } => {
494 f.push_str("EXPLAIN");
495 if *analyze {
496 f.push_str(" ANALYZE");
497 }
498 match explain_format {
499 ExplainFormat::Text => f.push_str(" FORMAT TEXT"),
500 ExplainFormat::Json => f.push_str(" FORMAT JSON"),
501 }
502 f.push(' ');
503 statement.fmt_sql(f, fmt);
504 }
505 }
506 }
507}