1use crate::ast::*;
2use crate::stdlib;
3use crate::{Env, Value};
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum EvalError {
8 #[error("Unbound variable: {0}")]
9 UnboundVar(String),
10 #[error("Type error: {0}")]
11 TypeError(String),
12 #[error("Missing key: {0}")]
13 MissingKey(String),
14 #[error("Wrong shape: {0}")]
15 WrongShape(String),
16 #[error("Duplicate key: {0}")]
17 DuplicateKey(String),
18 #[error("Not callable: {0}")]
19 NotCallable(String),
20 #[error("Arity mismatch: expected {expected}, got {got}")]
21 ArityMismatch { expected: usize, got: usize },
22}
23
24fn fail_value(code: &str, msg: &str) -> Value {
25 Value::Fail_(code.to_string(), msg.to_string())
26}
27
28fn wrong_shape_value() -> Value {
29 fail_value("t_sda_wrong_shape", "wrong shape")
30}
31
32fn div_by_zero_value() -> Value {
33 fail_value("t_sda_div_by_zero", "division by zero")
34}
35
36fn unbound_name_value() -> Value {
37 fail_value("t_sda_unbound_name", "unbound name")
38}
39
40fn not_callable_value() -> Value {
41 fail_value("t_sda_not_callable", "not callable")
42}
43
44fn arity_mismatch_value() -> Value {
45 fail_value("t_sda_arity_mismatch", "arity mismatch")
46}
47
48pub(crate) fn ensure_comparable(value: &Value) -> Result<(), EvalError> {
49 match value {
50 Value::Null
51 | Value::Bool(_)
52 | Value::Num(_)
53 | Value::Str(_)
54 | Value::Bytes(_)
55 | Value::None_
56 | Value::Fail_(_, _) => Ok(()),
57 Value::Seq(items) | Value::Set(items) | Value::Bag(items) => {
58 for item in items {
59 ensure_comparable(item)?;
60 }
61 Ok(())
62 }
63 Value::Map(entries) | Value::Prod(entries) => {
64 for (_, value) in entries {
65 ensure_comparable(value)?;
66 }
67 Ok(())
68 }
69 Value::BagKV(pairs) => {
70 for (key, value) in pairs {
71 ensure_comparable(key)?;
72 ensure_comparable(value)?;
73 }
74 Ok(())
75 }
76 Value::Bind(key, value) => {
77 ensure_comparable(key)?;
78 ensure_comparable(value)
79 }
80 Value::Some_(inner) | Value::Ok_(inner) => ensure_comparable(inner),
81 Value::Lambda(_, _, _) => Err(EvalError::TypeError(
82 "function values are not comparable".to_string(),
83 )),
84 }
85}
86
87pub(crate) fn values_equal(a: &Value, b: &Value) -> bool {
88 a == b
89}
90
91pub fn eval_expr(expr: &Expr, env: &Env) -> Result<Value, EvalError> {
92 match expr {
93 Expr::Null => Ok(Value::Null),
94 Expr::Bool(b) => Ok(Value::Bool(*b)),
95 Expr::Num(n) => Ok(Value::Num(n.clone())),
96 Expr::Str(s) => Ok(Value::Str(s.clone())),
97 Expr::Bytes(bytes) => Ok(Value::Bytes(bytes.clone())),
98 Expr::Placeholder => Ok(env.get("_").cloned().unwrap_or_else(|| {
99 Value::Fail_(
100 "t_sda_unbound_placeholder".to_string(),
101 "unbound placeholder".to_string(),
102 )
103 })),
104 Expr::Ident(name) => env
105 .get(name)
106 .cloned()
107 .map_or_else(|| Ok(unbound_name_value()), Ok),
108 Expr::Seq(items) => {
109 let values: Result<Vec<Value>, EvalError> =
110 items.iter().map(|item| eval_expr(item, env)).collect();
111 Ok(Value::Seq(values?))
112 }
113 Expr::Set(items) => {
114 let mut values = Vec::new();
115 for item in items {
116 let value = eval_expr(item, env)?;
117 if ensure_comparable(&value).is_err() {
118 return Ok(wrong_shape_value());
119 }
120 if !values.iter().any(|existing| values_equal(existing, &value)) {
121 values.push(value);
122 }
123 }
124 Ok(Value::Set(values))
125 }
126 Expr::Bag(items) => {
127 let values: Result<Vec<Value>, EvalError> =
128 items.iter().map(|item| eval_expr(item, env)).collect();
129 Ok(Value::Bag(values?))
130 }
131 Expr::Map(entries) => {
132 let mut result = Vec::new();
133 for (k, v) in entries {
134 result.push((k.clone(), eval_expr(v, env)?));
135 }
136 Ok(Value::Map(result))
137 }
138 Expr::Prod(fields) => {
139 let mut result = Vec::new();
140 for (k, v) in fields {
141 result.push((k.clone(), eval_expr(v, env)?));
142 }
143 Ok(Value::Prod(result))
144 }
145 Expr::BagKV(entries) => {
146 let mut result = Vec::new();
147 for (k, v) in entries {
148 result.push((Value::Str(k.clone()), eval_expr(v, env)?));
149 }
150 Ok(Value::BagKV(result))
151 }
152 Expr::Some_(inner) => Ok(Value::Some_(Box::new(eval_expr(inner, env)?))),
153 Expr::None_ => Ok(Value::None_),
154 Expr::Ok_(inner) => Ok(Value::Ok_(Box::new(eval_expr(inner, env)?))),
155 Expr::Fail_(code_expr, msg_expr) => {
156 let code_value = eval_expr(code_expr, env)?;
157 let msg_value = eval_expr(msg_expr, env)?;
158 let code = match code_value {
159 Value::Str(s) => s,
160 other => format!("{other:?}"),
161 };
162 let msg = match msg_value {
163 Value::Str(s) => s,
164 other => format!("{other:?}"),
165 };
166 Ok(Value::Fail_(code, msg))
167 }
168 Expr::Lambda(param, body) => Ok(Value::Lambda(
169 param.clone(),
170 body.clone(),
171 Box::new(env.clone()),
172 )),
173 Expr::Call(func_expr, args) => {
174 let arg_vals: Result<Vec<Value>, EvalError> =
175 args.iter().map(|arg| eval_expr(arg, env)).collect();
176 let arg_vals = arg_vals?;
177
178 if let Expr::Ident(name) = func_expr.as_ref() {
179 if let Some(result) = stdlib::call_stdlib(name, arg_vals.clone()) {
180 return match result {
181 Err(EvalError::ArityMismatch { .. }) => Ok(arity_mismatch_value()),
182 other => other,
183 };
184 }
185 let func = if let Some(func) = env.get(name).cloned() {
186 func
187 } else {
188 return Ok(unbound_name_value());
189 };
190 return apply_lambda(func, arg_vals);
191 }
192
193 let func = eval_expr(func_expr, env)?;
194 apply_lambda(func, arg_vals)
195 }
196 Expr::Pipe(lhs, rhs) => {
197 let lhs_value = eval_expr(lhs, env)?;
198 let mut child_env = env.clone();
199 child_env.insert("_".to_string(), lhs_value);
200 eval_expr(rhs, &child_env)
201 }
202 Expr::Select(obj_expr, field, mode) => {
203 let obj = eval_expr(obj_expr, env)?;
204 eval_select(obj, field, mode)
205 }
206 Expr::UnOp(op, expr) => {
207 let value = eval_expr(expr, env)?;
208 match op {
209 UnOpKind::Neg => match value {
210 Value::Num(n) => Ok(Value::Num(n.neg())),
211 _ => Ok(wrong_shape_value()),
212 },
213 UnOpKind::Not => match value {
214 Value::Bool(b) => Ok(Value::Bool(!b)),
215 _ => Ok(wrong_shape_value()),
216 },
217 }
218 }
219 Expr::BinOp(op, lhs_expr, rhs_expr) => {
220 let lhs = eval_expr(lhs_expr, env)?;
221 let rhs = eval_expr(rhs_expr, env)?;
222 eval_binop(op, lhs, rhs)
223 }
224 Expr::Comprehension {
225 yield_expr,
226 binding,
227 collection,
228 pred,
229 } => {
230 enum Carrier {
231 Seq,
232 Set,
233 Bag,
234 }
235
236 let coll_val = eval_expr(collection, env)?;
237 let (items, carrier) = match coll_val {
238 Value::Seq(items) => (items, Carrier::Seq),
239 Value::Set(items) => (items, Carrier::Set),
240 Value::Bag(items) => (items, Carrier::Bag),
241 Value::BagKV(entries) => (
242 entries
243 .into_iter()
244 .map(|(key, value)| Value::Bind(Box::new(key), Box::new(value)))
245 .collect(),
246 Carrier::Bag,
247 ),
248 _ => return Ok(wrong_shape_value()),
249 };
250
251 let mut results = Vec::new();
252 for item in items {
253 let mut child_env = env.clone();
254 child_env.insert(binding.clone(), item.clone());
255
256 if let Some(pred_expr) = pred {
257 let pred_val = eval_expr(pred_expr, &child_env)?;
258 match pred_val {
259 Value::Bool(false) => continue,
260 Value::Bool(true) => {}
261 _ => return Ok(wrong_shape_value()),
262 }
263 }
264
265 let result = if let Some(yield_expr) = yield_expr {
266 eval_expr(yield_expr, &child_env)?
267 } else {
268 item
269 };
270 results.push(result);
271 }
272
273 match carrier {
274 Carrier::Seq => Ok(Value::Seq(results)),
275 Carrier::Bag => Ok(Value::Bag(results)),
276 Carrier::Set => {
277 let mut dedup = Vec::new();
278 for value in results {
279 if ensure_comparable(&value).is_err() {
280 return Ok(wrong_shape_value());
281 }
282 if !dedup.iter().any(|existing| values_equal(existing, &value)) {
283 dedup.push(value);
284 }
285 }
286 Ok(Value::Set(dedup))
287 }
288 }
289 }
290 }
291}
292
293fn eval_select(obj: Value, field: &str, mode: &SelectMode) -> Result<Value, EvalError> {
294 match &obj {
295 Value::Map(entries) => {
296 let found = entries.iter().find(|(k, _)| k == field).map(|(_, v)| v.clone());
297 match mode {
298 SelectMode::Plain => Ok(Value::Fail_(
299 "t_sda_wrong_shape".to_string(),
300 "wrong shape".to_string(),
301 )),
302 SelectMode::Optional => Ok(found
303 .map(|v| Value::Some_(Box::new(v)))
304 .unwrap_or(Value::None_)),
305 SelectMode::Required => Ok(found
306 .map(|v| Value::Ok_(Box::new(v)))
307 .unwrap_or_else(|| {
308 Value::Fail_(
309 "t_sda_missing_key".to_string(),
310 "missing key".to_string(),
311 )
312 })),
313 }
314 }
315 Value::Prod(fields) => {
316 let found = fields.iter().find(|(k, _)| k == field).map(|(_, v)| v.clone());
317 match mode {
318 SelectMode::Plain => Ok(found.unwrap_or_else(|| {
319 Value::Fail_(
320 "t_sda_unknown_field".to_string(),
321 "unknown field".to_string(),
322 )
323 })),
324 SelectMode::Optional | SelectMode::Required => Ok(Value::Fail_(
325 "t_sda_wrong_shape".to_string(),
326 "wrong shape".to_string(),
327 )),
328 }
329 }
330 Value::Bind(key, value) => {
331 let found = match field {
332 "key" => Some((**key).clone()),
333 "val" => Some((**value).clone()),
334 _ => None,
335 };
336 match mode {
337 SelectMode::Plain => Ok(found.unwrap_or(Value::Null)),
338 SelectMode::Optional => Ok(found
339 .map(|v| Value::Some_(Box::new(v)))
340 .unwrap_or(Value::None_)),
341 SelectMode::Required => Ok(found
342 .map(|v| Value::Ok_(Box::new(v)))
343 .unwrap_or_else(|| {
344 Value::Fail_(
345 "t_sda_missing_key".to_string(),
346 "missing key".to_string(),
347 )
348 })),
349 }
350 }
351 Value::BagKV(entries) => {
352 let matches: Vec<_> = entries
353 .iter()
354 .filter(|(k, _)| matches!(k, Value::Str(s) if s == field))
355 .collect();
356 match mode {
357 SelectMode::Plain => Ok(Value::Fail_(
358 "t_sda_wrong_shape".to_string(),
359 "wrong shape".to_string(),
360 )),
361 SelectMode::Optional => match matches.len() {
362 0 => Ok(Value::None_),
363 1 => Ok(Value::Some_(Box::new(matches[0].1.clone()))),
364 _ => Ok(Value::None_),
365 },
366 SelectMode::Required => match matches.len() {
367 0 => Ok(Value::Fail_(
368 "t_sda_missing_key".to_string(),
369 "missing key".to_string(),
370 )),
371 1 => Ok(Value::Ok_(Box::new(matches[0].1.clone()))),
372 _ => Ok(Value::Fail_(
373 "t_sda_duplicate_key".to_string(),
374 "duplicate key".to_string(),
375 )),
376 },
377 }
378 }
379 _ => match mode {
380 SelectMode::Optional => Ok(Value::Fail_(
381 "t_sda_wrong_shape".to_string(),
382 "wrong shape".to_string(),
383 )),
384 SelectMode::Required => Ok(Value::Fail_(
385 "t_sda_wrong_shape".to_string(),
386 "wrong shape".to_string(),
387 )),
388 SelectMode::Plain => Ok(Value::Fail_(
389 "t_sda_wrong_shape".to_string(),
390 "wrong shape".to_string(),
391 )),
392 },
393 }
394}
395
396fn eval_binop(op: &BinOpKind, lhs: Value, rhs: Value) -> Result<Value, EvalError> {
397 match op {
398 BinOpKind::Add => match (lhs, rhs) {
399 (Value::Num(a), Value::Num(b)) => Ok(Value::Num(a.add(&b))),
400 _ => Ok(wrong_shape_value()),
401 },
402 BinOpKind::Sub => match (lhs, rhs) {
403 (Value::Num(a), Value::Num(b)) => Ok(Value::Num(a.sub(&b))),
404 _ => Ok(wrong_shape_value()),
405 },
406 BinOpKind::Mul => match (lhs, rhs) {
407 (Value::Num(a), Value::Num(b)) => Ok(Value::Num(a.mul(&b))),
408 _ => Ok(wrong_shape_value()),
409 },
410 BinOpKind::Div => match (lhs, rhs) {
411 (Value::Num(a), Value::Num(b)) => {
412 if b.is_zero() {
413 Ok(div_by_zero_value())
414 } else {
415 Ok(Value::Num(a.div(&b)))
416 }
417 }
418 _ => Ok(wrong_shape_value()),
419 },
420 BinOpKind::Concat => match (lhs, rhs) {
421 (Value::Str(a), Value::Str(b)) => Ok(Value::Str(a + &b)),
422 (Value::Seq(mut a), Value::Seq(b)) => {
423 a.extend(b);
424 Ok(Value::Seq(a))
425 }
426 _ => Ok(wrong_shape_value()),
427 },
428 BinOpKind::Eq => {
429 if ensure_comparable(&lhs).is_err() || ensure_comparable(&rhs).is_err() {
430 return Ok(wrong_shape_value());
431 }
432 Ok(Value::Bool(values_equal(&lhs, &rhs)))
433 }
434 BinOpKind::Neq => {
435 if ensure_comparable(&lhs).is_err() || ensure_comparable(&rhs).is_err() {
436 return Ok(wrong_shape_value());
437 }
438 Ok(Value::Bool(!values_equal(&lhs, &rhs)))
439 }
440 BinOpKind::Lt => match (lhs, rhs) {
441 (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a < b)),
442 (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a < b)),
443 _ => Ok(wrong_shape_value()),
444 },
445 BinOpKind::Le => match (lhs, rhs) {
446 (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a <= b)),
447 (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a <= b)),
448 _ => Ok(wrong_shape_value()),
449 },
450 BinOpKind::Gt => match (lhs, rhs) {
451 (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a > b)),
452 (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a > b)),
453 _ => Ok(wrong_shape_value()),
454 },
455 BinOpKind::Ge => match (lhs, rhs) {
456 (Value::Num(a), Value::Num(b)) => Ok(Value::Bool(a >= b)),
457 (Value::Str(a), Value::Str(b)) => Ok(Value::Bool(a >= b)),
458 _ => Ok(wrong_shape_value()),
459 },
460 BinOpKind::And => match (lhs, rhs) {
461 (Value::Bool(a), Value::Bool(b)) => Ok(Value::Bool(a && b)),
462 _ => Ok(wrong_shape_value()),
463 },
464 BinOpKind::Or => match (lhs, rhs) {
465 (Value::Bool(a), Value::Bool(b)) => Ok(Value::Bool(a || b)),
466 _ => Ok(wrong_shape_value()),
467 },
468 BinOpKind::Union => match (lhs, rhs) {
469 (Value::Set(mut a), Value::Set(b)) => {
470 for item in b {
471 if !a.iter().any(|existing| values_equal(existing, &item)) {
472 a.push(item);
473 }
474 }
475 Ok(Value::Set(a))
476 }
477 _ => Ok(wrong_shape_value()),
478 },
479 BinOpKind::Inter => match (lhs, rhs) {
480 (Value::Set(a), Value::Set(b)) => {
481 let result = a
482 .into_iter()
483 .filter(|x| b.iter().any(|y| values_equal(x, y)))
484 .collect();
485 Ok(Value::Set(result))
486 }
487 _ => Ok(wrong_shape_value()),
488 },
489 BinOpKind::Diff => match (lhs, rhs) {
490 (Value::Set(a), Value::Set(b)) => {
491 let result = a
492 .into_iter()
493 .filter(|x| !b.iter().any(|y| values_equal(x, y)))
494 .collect();
495 Ok(Value::Set(result))
496 }
497 _ => Ok(wrong_shape_value()),
498 },
499 BinOpKind::BUnion => match (lhs, rhs) {
500 (Value::Bag(mut a), Value::Bag(b)) => {
501 a.extend(b);
502 Ok(Value::Bag(a))
503 }
504 _ => Ok(wrong_shape_value()),
505 },
506 BinOpKind::BDiff => match (lhs, rhs) {
507 (Value::Bag(a), Value::Bag(b)) => {
508 let mut remaining = b.clone();
509 let result = a
510 .into_iter()
511 .filter(|x| {
512 if let Some(idx) = remaining.iter().position(|y| values_equal(x, y)) {
513 remaining.remove(idx);
514 false
515 } else {
516 true
517 }
518 })
519 .collect();
520 Ok(Value::Bag(result))
521 }
522 _ => Ok(wrong_shape_value()),
523 },
524 BinOpKind::In => match rhs {
525 Value::Seq(items) => {
526 if ensure_comparable(&lhs).is_err() {
527 return Ok(wrong_shape_value());
528 }
529 for item in &items {
530 if ensure_comparable(item).is_err() {
531 return Ok(wrong_shape_value());
532 }
533 }
534 Ok(Value::Bool(items.iter().any(|x| values_equal(x, &lhs))))
535 }
536 Value::Set(items) => {
537 if ensure_comparable(&lhs).is_err() {
538 return Ok(wrong_shape_value());
539 }
540 for item in &items {
541 if ensure_comparable(item).is_err() {
542 return Ok(wrong_shape_value());
543 }
544 }
545 Ok(Value::Bool(items.iter().any(|x| values_equal(x, &lhs))))
546 }
547 Value::Bag(items) => {
548 if ensure_comparable(&lhs).is_err() {
549 return Ok(wrong_shape_value());
550 }
551 for item in &items {
552 if ensure_comparable(item).is_err() {
553 return Ok(wrong_shape_value());
554 }
555 }
556 Ok(Value::Bool(items.iter().any(|x| values_equal(x, &lhs))))
557 }
558 Value::Map(entries) => {
559 if let Value::Str(key) = &lhs {
560 Ok(Value::Bool(entries.iter().any(|(k, _)| k == key)))
561 } else {
562 Ok(wrong_shape_value())
563 }
564 }
565 Value::Prod(fields) => {
566 if let Value::Str(key) = &lhs {
567 Ok(Value::Bool(fields.iter().any(|(k, _)| k == key)))
568 } else {
569 Ok(wrong_shape_value())
570 }
571 }
572 _ => Ok(wrong_shape_value()),
573 },
574 }
575}
576
577pub(crate) fn apply_lambda(func: Value, args: Vec<Value>) -> Result<Value, EvalError> {
578 match func {
579 Value::Lambda(param, body, captured_env) => {
580 if args.len() != 1 {
581 return Ok(arity_mismatch_value());
582 }
583 let mut new_env = *captured_env;
584 new_env.insert(param, args.into_iter().next().unwrap());
585 eval_expr(&body, &new_env)
586 }
587 _ => Ok(not_callable_value()),
588 }
589}
590
591pub fn eval_program(program: &Program, env: &mut Env) -> Result<Option<Value>, EvalError> {
592 let mut last = None;
593 for stmt in &program.stmts {
594 match stmt {
595 Stmt::Let(name, expr) => {
596 let value = eval_expr(expr, env)?;
597 env.insert(name.clone(), value);
598 last = None;
599 }
600 Stmt::Expr(expr) => {
601 last = Some(eval_expr(expr, env)?);
602 }
603 }
604 }
605 Ok(last)
606}