1use std::collections::HashMap;
10
11use arrow_array::RecordBatch;
12use uni_common::Value;
13use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr, UnaryOp};
14use uni_cypher::locy_ast::{LocyBinaryOp, LocyExpr};
15use uni_locy::{FactRow, LocyError};
16
17pub fn eval_locy_expr(
20 expr: &LocyExpr,
21 bindings: &FactRow,
22 prev_values: Option<&FactRow>,
23) -> Result<Value, LocyError> {
24 match expr {
25 LocyExpr::PrevRef(field) => Ok(prev_values
26 .and_then(|prev| prev.get(field).cloned())
27 .unwrap_or(Value::Null)),
28 LocyExpr::Cypher(cypher_expr) => eval_expr(cypher_expr, bindings),
29 LocyExpr::BinaryOp { left, op, right } => {
30 let l = eval_locy_expr(left, bindings, prev_values)?;
31 let r = eval_locy_expr(right, bindings, prev_values)?;
32 eval_locy_binary_op(&l, op, &r)
33 }
34 LocyExpr::UnaryOp(op, inner) => {
35 let v = eval_locy_expr(inner, bindings, prev_values)?;
36 eval_unary_op(op, &v)
37 }
38 }
39}
40
41pub fn eval_expr(expr: &Expr, bindings: &FactRow) -> Result<Value, LocyError> {
43 match expr {
44 Expr::Literal(lit) => Ok(literal_to_value(lit)),
45 Expr::Variable(name) => Ok(bindings.get(name).cloned().unwrap_or(Value::Null)),
46 Expr::Property(expr, property) => {
47 let base = eval_expr(expr, bindings)?;
48 Ok(get_property(&base, property))
49 }
50 Expr::BinaryOp { left, op, right } => {
51 let l = eval_expr(left, bindings)?;
52 let r = eval_expr(right, bindings)?;
53 eval_binary_op(&l, op, &r)
54 }
55 Expr::UnaryOp { op, expr } => {
56 let v = eval_expr(expr, bindings)?;
57 eval_unary_op(op, &v)
58 }
59 Expr::FunctionCall { name, args, .. } => {
60 let evaluated_args: Result<Vec<Value>, _> =
61 args.iter().map(|a| eval_expr(a, bindings)).collect();
62 eval_function(name, &evaluated_args?)
63 }
64 Expr::Parameter(name) => Ok(bindings.get(name).cloned().unwrap_or(Value::Null)),
65 Expr::IsNull(inner) => {
66 let v = eval_expr(inner, bindings)?;
67 Ok(Value::Bool(v.is_null()))
68 }
69 Expr::IsNotNull(inner) => {
70 let v = eval_expr(inner, bindings)?;
71 Ok(Value::Bool(!v.is_null()))
72 }
73 Expr::List(items) => {
74 let vals: Result<Vec<Value>, _> =
75 items.iter().map(|i| eval_expr(i, bindings)).collect();
76 Ok(Value::List(vals?))
77 }
78 Expr::Map(entries) => {
79 let mut map = HashMap::new();
80 for (k, v) in entries {
81 map.insert(k.clone(), eval_expr(v, bindings)?);
82 }
83 Ok(Value::Map(map))
84 }
85 _ => Err(LocyError::EvaluationError {
86 message: format!("unsupported expression in in-memory evaluation: {expr:?}"),
87 }),
88 }
89}
90
91pub fn eval_aggregate_over_group(
93 func_name: &str,
94 arg_expr: &Expr,
95 group: &[FactRow],
96 rule_name: &str,
97 fold_name: &str,
98) -> Result<Value, LocyError> {
99 let upper = func_name.to_uppercase();
100 match upper.as_str() {
101 "SUM" => {
102 let mut total = 0.0_f64;
103 for row in group {
104 let v = eval_expr(arg_expr, row)?;
105 if let Some(f) = v.as_f64() {
106 total += f;
107 }
108 }
109 if total == total.floor() && total.abs() < i64::MAX as f64 {
110 Ok(Value::Int(total as i64))
111 } else {
112 Ok(Value::Float(total))
113 }
114 }
115 "MSUM" => {
116 let mut total = 0.0_f64;
117 for row in group {
118 let v = eval_expr(arg_expr, row)?;
119 if let Some(f) = v.as_f64() {
120 if f < 0.0 {
121 return Err(LocyError::MsumNegativeValue {
122 rule: rule_name.to_string(),
123 fold: fold_name.to_string(),
124 value: f,
125 });
126 }
127 total += f;
128 }
129 }
130 if total == total.floor() && total.abs() < i64::MAX as f64 {
131 Ok(Value::Int(total as i64))
132 } else {
133 Ok(Value::Float(total))
134 }
135 }
136 "COUNT" | "MCOUNT" => {
137 let count = group
138 .iter()
139 .filter(|row| {
140 eval_expr(arg_expr, row)
141 .map(|v| !v.is_null())
142 .unwrap_or(false)
143 })
144 .count();
145 Ok(Value::Int(count as i64))
146 }
147 "MIN" | "MMIN" => {
148 let mut min_val: Option<Value> = None;
149 for row in group {
150 let v = eval_expr(arg_expr, row)?;
151 if v.is_null() {
152 continue;
153 }
154 min_val = Some(match min_val {
155 None => v,
156 Some(cur) => {
157 if value_less_than(&v, &cur) {
158 v
159 } else {
160 cur
161 }
162 }
163 });
164 }
165 Ok(min_val.unwrap_or(Value::Null))
166 }
167 "MAX" | "MMAX" => {
168 let mut max_val: Option<Value> = None;
169 for row in group {
170 let v = eval_expr(arg_expr, row)?;
171 if v.is_null() {
172 continue;
173 }
174 max_val = Some(match max_val {
175 None => v,
176 Some(cur) => {
177 if value_less_than(&cur, &v) {
178 v
179 } else {
180 cur
181 }
182 }
183 });
184 }
185 Ok(max_val.unwrap_or(Value::Null))
186 }
187 "AVG" => {
188 let mut total = 0.0_f64;
189 let mut count = 0;
190 for row in group {
191 let v = eval_expr(arg_expr, row)?;
192 if let Some(f) = v.as_f64() {
193 total += f;
194 count += 1;
195 }
196 }
197 if count == 0 {
198 Ok(Value::Null)
199 } else {
200 Ok(Value::Float(total / count as f64))
201 }
202 }
203 "COLLECT" => {
204 let mut vals = Vec::new();
205 for row in group {
206 let v = eval_expr(arg_expr, row)?;
207 if !v.is_null() {
208 vals.push(v);
209 }
210 }
211 Ok(Value::List(vals))
212 }
213 _ => Err(LocyError::EvaluationError {
214 message: format!("unknown aggregate function: {func_name}"),
215 }),
216 }
217}
218
219pub(crate) fn literal_to_value(lit: &CypherLiteral) -> Value {
220 match lit {
221 CypherLiteral::Null => Value::Null,
222 CypherLiteral::Bool(b) => Value::Bool(*b),
223 CypherLiteral::Integer(i) => Value::Int(*i),
224 CypherLiteral::Float(f) => Value::Float(*f),
225 CypherLiteral::String(s) => Value::String(s.clone()),
226 CypherLiteral::Bytes(b) => Value::Bytes(b.clone()),
227 }
228}
229
230fn get_property(value: &Value, property: &str) -> Value {
231 match value {
232 Value::Node(n) => n.properties.get(property).cloned().unwrap_or(Value::Null),
233 Value::Edge(e) => e.properties.get(property).cloned().unwrap_or(Value::Null),
234 Value::Map(m) => m.get(property).cloned().unwrap_or(Value::Null),
235 _ => Value::Null,
236 }
237}
238
239fn eval_unary_op(op: &UnaryOp, v: &Value) -> Result<Value, LocyError> {
244 match op {
245 UnaryOp::Not => match v {
246 Value::Bool(b) => Ok(Value::Bool(!b)),
247 Value::Null => Ok(Value::Null),
248 _ => Err(LocyError::TypeError {
249 message: format!("NOT requires boolean, got {v:?}"),
250 }),
251 },
252 UnaryOp::Neg => match v {
253 Value::Int(i) => Ok(Value::Int(-i)),
254 Value::Float(f) => Ok(Value::Float(-f)),
255 Value::Null => Ok(Value::Null),
256 _ => Err(LocyError::TypeError {
257 message: format!("negation requires numeric, got {v:?}"),
258 }),
259 },
260 }
261}
262
263fn eval_locy_binary_op(left: &Value, op: &LocyBinaryOp, right: &Value) -> Result<Value, LocyError> {
264 if left.is_null() || right.is_null() {
265 return Ok(Value::Null);
266 }
267 match op {
268 LocyBinaryOp::Add => numeric_op(left, right, |a, b| a + b, |a, b| a + b),
269 LocyBinaryOp::Sub => numeric_op(left, right, |a, b| a - b, |a, b| a - b),
270 LocyBinaryOp::Mul => numeric_op(left, right, |a, b| a * b, |a, b| a * b),
271 LocyBinaryOp::Div => {
272 let r = right.as_f64().unwrap_or(0.0);
273 if r == 0.0 {
274 return Err(LocyError::EvaluationError {
275 message: "division by zero".to_string(),
276 });
277 }
278 numeric_op(left, right, |a, b| a / b, |a, b| a / b)
279 }
280 LocyBinaryOp::Mod => numeric_op(left, right, |a, b| a % b, |a, b| a % b),
281 LocyBinaryOp::Pow => {
282 let l = left.as_f64().ok_or_else(|| LocyError::TypeError {
283 message: format!("pow requires numeric, got {left:?}"),
284 })?;
285 let r = right.as_f64().ok_or_else(|| LocyError::TypeError {
286 message: format!("pow requires numeric, got {right:?}"),
287 })?;
288 Ok(Value::Float(l.powf(r)))
289 }
290 LocyBinaryOp::And => match (left.as_bool(), right.as_bool()) {
291 (Some(a), Some(b)) => Ok(Value::Bool(a && b)),
292 _ => Ok(Value::Null),
293 },
294 LocyBinaryOp::Or => match (left.as_bool(), right.as_bool()) {
295 (Some(a), Some(b)) => Ok(Value::Bool(a || b)),
296 _ => Ok(Value::Null),
297 },
298 LocyBinaryOp::Xor => match (left.as_bool(), right.as_bool()) {
299 (Some(a), Some(b)) => Ok(Value::Bool(a ^ b)),
300 _ => Ok(Value::Null),
301 },
302 }
303}
304
305fn eval_binary_op(left: &Value, op: &BinaryOp, right: &Value) -> Result<Value, LocyError> {
306 if left.is_null() || right.is_null() {
307 return match op {
308 BinaryOp::Eq => Ok(Value::Bool(left.is_null() && right.is_null())),
309 BinaryOp::NotEq => Ok(Value::Bool(!(left.is_null() && right.is_null()))),
310 _ => Ok(Value::Null),
311 };
312 }
313 match op {
314 BinaryOp::Add => numeric_op(left, right, |a, b| a + b, |a, b| a + b),
315 BinaryOp::Sub => numeric_op(left, right, |a, b| a - b, |a, b| a - b),
316 BinaryOp::Mul => numeric_op(left, right, |a, b| a * b, |a, b| a * b),
317 BinaryOp::Div => numeric_op(left, right, |a, b| a / b, |a, b| a / b),
318 BinaryOp::Mod => numeric_op(left, right, |a, b| a % b, |a, b| a % b),
319 BinaryOp::Pow => {
320 let l = left.as_f64().unwrap_or(0.0);
321 let r = right.as_f64().unwrap_or(0.0);
322 Ok(Value::Float(l.powf(r)))
323 }
324 BinaryOp::Eq => Ok(Value::Bool(values_equal(left, right))),
325 BinaryOp::NotEq => Ok(Value::Bool(!values_equal(left, right))),
326 BinaryOp::Lt => Ok(Value::Bool(value_less_than(left, right))),
327 BinaryOp::LtEq => Ok(Value::Bool(
328 value_less_than(left, right) || values_equal(left, right),
329 )),
330 BinaryOp::Gt => Ok(Value::Bool(value_less_than(right, left))),
331 BinaryOp::GtEq => Ok(Value::Bool(
332 value_less_than(right, left) || values_equal(left, right),
333 )),
334 BinaryOp::And => match (left.as_bool(), right.as_bool()) {
335 (Some(a), Some(b)) => Ok(Value::Bool(a && b)),
336 _ => Ok(Value::Null),
337 },
338 BinaryOp::Or => match (left.as_bool(), right.as_bool()) {
339 (Some(a), Some(b)) => Ok(Value::Bool(a || b)),
340 _ => Ok(Value::Null),
341 },
342 BinaryOp::Xor => match (left.as_bool(), right.as_bool()) {
343 (Some(a), Some(b)) => Ok(Value::Bool(a ^ b)),
344 _ => Ok(Value::Null),
345 },
346 BinaryOp::Contains => match (left.as_str(), right.as_str()) {
347 (Some(l), Some(r)) => Ok(Value::Bool(l.contains(r))),
348 _ => Ok(Value::Null),
349 },
350 BinaryOp::StartsWith => match (left.as_str(), right.as_str()) {
351 (Some(l), Some(r)) => Ok(Value::Bool(l.starts_with(r))),
352 _ => Ok(Value::Null),
353 },
354 BinaryOp::EndsWith => match (left.as_str(), right.as_str()) {
355 (Some(l), Some(r)) => Ok(Value::Bool(l.ends_with(r))),
356 _ => Ok(Value::Null),
357 },
358 _ => Err(LocyError::EvaluationError {
359 message: format!("unsupported binary op in in-memory evaluation: {op:?}"),
360 }),
361 }
362}
363
364fn numeric_op(
365 left: &Value,
366 right: &Value,
367 int_op: impl Fn(i64, i64) -> i64,
368 float_op: impl Fn(f64, f64) -> f64,
369) -> Result<Value, LocyError> {
370 match (left, right) {
371 (Value::Int(a), Value::Int(b)) => Ok(Value::Int(int_op(*a, *b))),
372 _ => {
373 let a = left.as_f64().ok_or_else(|| LocyError::TypeError {
374 message: format!("numeric op requires number, got {left:?}"),
375 })?;
376 let b = right.as_f64().ok_or_else(|| LocyError::TypeError {
377 message: format!("numeric op requires number, got {right:?}"),
378 })?;
379 Ok(Value::Float(float_op(a, b)))
380 }
381 }
382}
383
384fn eval_function(name: &str, args: &[Value]) -> Result<Value, LocyError> {
385 let upper = name.to_uppercase();
386 match upper.as_str() {
387 "TOINTEGER" | "TOINT" => {
388 let v = args.first().unwrap_or(&Value::Null);
389 match v {
390 Value::Int(i) => Ok(Value::Int(*i)),
391 Value::Float(f) => Ok(Value::Int(*f as i64)),
392 Value::String(s) => {
393 s.parse::<i64>()
394 .map(Value::Int)
395 .map_err(|_| LocyError::TypeError {
396 message: format!("cannot convert '{s}' to integer"),
397 })
398 }
399 _ => Ok(Value::Null),
400 }
401 }
402 "TOFLOAT" => {
403 let v = args.first().unwrap_or(&Value::Null);
404 match v {
405 Value::Float(f) => Ok(Value::Float(*f)),
406 Value::Int(i) => Ok(Value::Float(*i as f64)),
407 Value::String(s) => {
408 s.parse::<f64>()
409 .map(Value::Float)
410 .map_err(|_| LocyError::TypeError {
411 message: format!("cannot convert '{s}' to float"),
412 })
413 }
414 _ => Ok(Value::Null),
415 }
416 }
417 "TOSTRING" => {
418 let v = args.first().unwrap_or(&Value::Null);
419 match v {
420 Value::String(s) => Ok(Value::String(s.clone())),
421 Value::Int(i) => Ok(Value::String(i.to_string())),
422 Value::Float(f) => Ok(Value::String(f.to_string())),
423 Value::Bool(b) => Ok(Value::String(b.to_string())),
424 Value::Null => Ok(Value::Null),
425 _ => Ok(Value::String(format!("{v:?}"))),
426 }
427 }
428 "ABS" => {
429 let v = args.first().unwrap_or(&Value::Null);
430 match v {
431 Value::Int(i) => Ok(Value::Int(i.abs())),
432 Value::Float(f) => Ok(Value::Float(f.abs())),
433 _ => Ok(Value::Null),
434 }
435 }
436 "COALESCE" => {
437 for a in args {
438 if !a.is_null() {
439 return Ok(a.clone());
440 }
441 }
442 Ok(Value::Null)
443 }
444 "SIMILAR_TO" | "VECTOR_SIMILARITY" => {
445 if args.len() < 2 {
446 return Err(LocyError::EvaluationError {
447 message: format!("{name} requires at least 2 arguments"),
448 });
449 }
450 crate::query::similar_to::eval_similar_to_pure(&args[0], &args[1]).map_err(|e| {
454 LocyError::EvaluationError {
455 message: e.to_string(),
456 }
457 })
458 }
459 _ => crate::query::expr_eval::eval_scalar_function(name, args, None).map_err(|e| {
464 LocyError::EvaluationError {
465 message: e.to_string(),
466 }
467 }),
468 }
469}
470
471pub fn values_equal(a: &Value, b: &Value) -> bool {
473 match (a, b) {
474 (Value::Int(x), Value::Float(y)) => (*x as f64) == *y,
475 (Value::Float(x), Value::Int(y)) => *x == (*y as f64),
476 _ => a == b,
477 }
478}
479
480pub fn values_equal_for_join(a: &Value, b: &Value) -> bool {
488 match (a, b) {
489 (Value::Node(na), Value::Node(nb)) => na.vid == nb.vid,
490 (Value::Edge(ea), Value::Edge(eb)) => ea.eid == eb.eid,
491 _ => values_equal(a, b),
492 }
493}
494
495pub fn value_cmp(a: &Value, b: &Value) -> std::cmp::Ordering {
497 if value_less_than(a, b) {
498 std::cmp::Ordering::Less
499 } else if value_less_than(b, a) {
500 std::cmp::Ordering::Greater
501 } else {
502 std::cmp::Ordering::Equal
503 }
504}
505
506pub fn value_less_than(a: &Value, b: &Value) -> bool {
508 match (a, b) {
509 (Value::Int(x), Value::Int(y)) => x < y,
510 (Value::Float(x), Value::Float(y)) => x < y,
511 (Value::Int(x), Value::Float(y)) => (*x as f64) < *y,
512 (Value::Float(x), Value::Int(y)) => *x < (*y as f64),
513 (Value::String(x), Value::String(y)) => x < y,
514 _ => false,
515 }
516}
517
518pub fn value_compare(a: &Value, b: &Value, null_last: bool) -> std::cmp::Ordering {
520 use std::cmp::Ordering;
521 let null_order = if null_last {
522 Ordering::Greater
523 } else {
524 Ordering::Less
525 };
526 match (a.is_null(), b.is_null()) {
527 (true, true) => Ordering::Equal,
528 (true, false) => null_order,
529 (false, true) => null_order.reverse(),
530 (false, false) => value_cmp(a, b),
531 }
532}
533
534pub fn record_batches_to_locy_rows(batches: &[RecordBatch]) -> Vec<FactRow> {
543 let mut rows = Vec::new();
544 for batch in batches {
545 let schema = batch.schema();
546 for row_idx in 0..batch.num_rows() {
547 let mut row = HashMap::new();
548 for (col_idx, field) in schema.fields().iter().enumerate() {
549 if field.name().starts_with("__feat_") {
554 continue;
555 }
556 let column = batch.column(col_idx);
557 let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
558 Some(&uni_common::DataType::DateTime)
559 } else if uni_common::core::schema::is_time_struct(field.data_type()) {
560 Some(&uni_common::DataType::Time)
561 } else {
562 None
563 };
564 let value = uni_store::storage::arrow_convert::arrow_to_value(
565 column.as_ref(),
566 row_idx,
567 data_type,
568 );
569 row.insert(field.name().clone(), value);
570 }
571 normalize_graph_row(&mut row);
572 rows.push(row);
573 }
574 }
575 rows
576}
577
578pub(crate) fn normalize_graph_row(row: &mut FactRow) {
587 let entity_vars: Vec<String> = row
590 .keys()
591 .filter(|k| {
592 !k.contains('.')
593 && match row.get(*k) {
594 Some(Value::Map(m)) => m.contains_key("_vid") || m.contains_key("_eid"),
595 _ => false,
596 }
597 })
598 .cloned()
599 .collect();
600
601 for var in &entity_vars {
602 let prefix = format!("{}.", var);
605 let helper_keys: Vec<String> = row
606 .keys()
607 .filter(|k| k.starts_with(&prefix))
608 .cloned()
609 .collect();
610 for key in &helper_keys {
611 let prop_name = &key[prefix.len()..];
612 if let Some(val) = row.get(key).cloned()
613 && let Some(Value::Map(m)) = row.get_mut(var)
614 {
615 m.entry(prop_name.to_string()).or_insert(val);
616 }
617 }
618 for key in helper_keys {
620 row.remove(&key);
621 }
622
623 if let Some(Value::Map(map)) = row.remove(var) {
625 row.insert(var.clone(), map_to_graph_entity(map));
626 }
627 }
628}
629
630fn map_to_graph_entity(map: HashMap<String, Value>) -> Value {
632 use uni_common::core::id::{Eid, Vid};
633 use uni_common::value::{Edge, Node};
634
635 if let Some(eid_val) = map.get("_eid") {
637 let eid = match eid_val {
638 Value::Int(i) => Eid::new(*i as u64),
639 _ => return Value::Map(map),
640 };
641 let edge_type = match map.get("_type") {
642 Some(Value::String(s)) => s.clone(),
643 _ => String::new(),
644 };
645 let src = match map.get("_src_vid") {
646 Some(Value::Int(i)) => Vid::new(*i as u64),
647 _ => Vid::new(0),
648 };
649 let dst = match map.get("_dst_vid") {
650 Some(Value::Int(i)) => Vid::new(*i as u64),
651 _ => Vid::new(0),
652 };
653 let properties = extract_properties_from_map(&map);
654 return Value::Edge(Edge {
655 eid,
656 edge_type,
657 src,
658 dst,
659 properties,
660 });
661 }
662
663 if let Some(vid_val) = map.get("_vid") {
665 let vid = match vid_val {
666 Value::Int(i) => Vid::new(*i as u64),
667 _ => return Value::Map(map),
668 };
669 let labels = match map.get("_labels") {
670 Some(Value::List(list)) => list
671 .iter()
672 .filter_map(|v| match v {
673 Value::String(s) => Some(s.clone()),
674 _ => None,
675 })
676 .collect(),
677 _ => Vec::new(),
678 };
679 let properties = extract_properties_from_map(&map);
680 return Value::Node(Node {
681 vid,
682 labels,
683 properties,
684 });
685 }
686
687 Value::Map(map)
688}
689
690fn extract_properties_from_map(map: &HashMap<String, Value>) -> HashMap<String, Value> {
696 let mut properties = HashMap::new();
697
698 if let Some(Value::Map(all_props)) = map.get("_all_props") {
700 for (k, v) in all_props {
701 properties.insert(k.clone(), v.clone());
702 }
703 }
704
705 for (k, v) in map {
707 if !k.starts_with('_') && k != "properties" {
708 properties.entry(k.clone()).or_insert_with(|| v.clone());
709 }
710 }
711
712 properties
713}