1use std::collections::HashMap;
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{Alt, AltCon, CoreExpr, CoreFrame, DataConId, Literal, PrimOpKind, VarId};
4
5#[derive(Debug, Clone, PartialEq, Eq)]
7enum PartialValue {
8 Known(KnownValue),
10 Unknown,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq)]
16enum KnownValue {
17 Lit(Literal),
19 Con(DataConId, Vec<KnownValue>),
21}
22
23type PartialEnv = HashMap<VarId, PartialValue>;
25
/// Optimization pass (see its [`Pass`] impl) that partially evaluates a
/// [`CoreExpr`]. It propagates known literals and constructors through
/// non-recursive `let`s, folds integer primops whose arguments are all
/// literals, and resolves `case` expressions whose scrutinee is statically
/// known.
pub struct PartialEval;
28
29impl Pass for PartialEval {
30 fn run(&self, expr: &mut CoreExpr) -> Changed {
31 if expr.nodes.is_empty() {
32 return false;
33 }
34 let mut new_nodes = Vec::new();
35 let (root_idx, _) = partial_eval_at(
36 expr,
37 expr.nodes.len() - 1,
38 &PartialEnv::new(),
39 &mut new_nodes,
40 );
41 let new_expr = CoreExpr { nodes: new_nodes }.extract_subtree(root_idx);
42 if new_expr != *expr {
43 *expr = new_expr;
44 true
45 } else {
46 false
47 }
48 }
49 fn name(&self) -> &str {
50 "PartialEval"
51 }
52}
53
54fn partial_eval_at(
56 expr: &CoreExpr,
57 idx: usize,
58 env: &PartialEnv,
59 new_nodes: &mut Vec<CoreFrame<usize>>,
60) -> (usize, PartialValue) {
61 match &expr.nodes[idx] {
62 CoreFrame::Var(v) => match env.get(v) {
63 Some(PartialValue::Known(kv)) => {
64 let ni = emit_known(kv, new_nodes);
65 (ni, PartialValue::Known(kv.clone()))
66 }
67 _ => {
68 let ni = new_nodes.len();
69 new_nodes.push(CoreFrame::Var(*v));
70 (ni, PartialValue::Unknown)
71 }
72 },
73 CoreFrame::Lit(lit) => {
74 let ni = new_nodes.len();
75 new_nodes.push(CoreFrame::Lit(lit.clone()));
76 (ni, PartialValue::Known(KnownValue::Lit(lit.clone())))
77 }
78 CoreFrame::Con { tag, fields } => {
79 let mut fi = Vec::new();
80 let mut fv = Vec::new();
81 for &f in fields {
82 let (i, v) = partial_eval_at(expr, f, env, new_nodes);
83 fi.push(i);
84 fv.push(v);
85 }
86 let ni = new_nodes.len();
87 new_nodes.push(CoreFrame::Con {
88 tag: *tag,
89 fields: fi,
90 });
91 let mut known_fields = Vec::new();
92 for v in fv {
93 if let PartialValue::Known(k) = v {
94 known_fields.push(k);
95 } else {
96 return (ni, PartialValue::Unknown);
97 }
98 }
99 (ni, PartialValue::Known(KnownValue::Con(*tag, known_fields)))
100 }
101 CoreFrame::LetNonRec { binder, rhs, body } => {
102 let (rhs_i, rhs_v) = partial_eval_at(expr, *rhs, env, new_nodes);
103 let mut new_env = env.clone();
104 new_env.insert(*binder, rhs_v.clone());
105 if matches!(rhs_v, PartialValue::Known(_)) {
106 partial_eval_at(expr, *body, &new_env, new_nodes)
108 } else {
109 let (body_i, body_v) = partial_eval_at(expr, *body, &new_env, new_nodes);
110 let ni = new_nodes.len();
111 new_nodes.push(CoreFrame::LetNonRec {
112 binder: *binder,
113 rhs: rhs_i,
114 body: body_i,
115 });
116 (ni, body_v)
117 }
118 }
119 CoreFrame::LetRec { bindings, body } => {
120 let mut new_env = env.clone();
121 for (b, _) in bindings {
122 new_env.insert(*b, PartialValue::Unknown);
123 }
124 let mut nb = Vec::new();
125 for (b, r) in bindings {
126 let (ri, _) = partial_eval_at(expr, *r, &new_env, new_nodes);
127 nb.push((*b, ri));
128 }
129 let (bi, bv) = partial_eval_at(expr, *body, &new_env, new_nodes);
130 let ni = new_nodes.len();
131 new_nodes.push(CoreFrame::LetRec {
132 bindings: nb,
133 body: bi,
134 });
135 (ni, bv)
136 }
137 CoreFrame::Case {
138 scrutinee,
139 binder,
140 alts,
141 } => {
142 let (si, sv) = partial_eval_at(expr, *scrutinee, env, new_nodes);
143 match &sv {
144 PartialValue::Known(KnownValue::Con(tag, field_vals)) => {
145 let matched = alts
146 .iter()
147 .find(|a| matches!(&a.con, AltCon::DataAlt(t) if t == tag))
148 .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
149 if let Some(alt) = matched {
150 let mut new_env = env.clone();
151 new_env.insert(*binder, sv.clone());
152 if let AltCon::DataAlt(_) = &alt.con {
153 for (b, fv) in alt.binders.iter().zip(field_vals.iter()) {
154 new_env.insert(*b, PartialValue::Known(fv.clone()));
155 }
156 }
157 partial_eval_at(expr, alt.body, &new_env, new_nodes)
158 } else {
159 emit_residual_case(expr, si, binder, alts, env, new_nodes)
160 }
161 }
162 PartialValue::Known(KnownValue::Lit(lit)) => {
163 let matched = alts
164 .iter()
165 .find(|a| matches!(&a.con, AltCon::LitAlt(l) if l == lit))
166 .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
167 if let Some(alt) = matched {
168 let mut new_env = env.clone();
169 new_env.insert(*binder, sv.clone());
170 partial_eval_at(expr, alt.body, &new_env, new_nodes)
171 } else {
172 emit_residual_case(expr, si, binder, alts, env, new_nodes)
173 }
174 }
175 PartialValue::Unknown => emit_residual_case(expr, si, binder, alts, env, new_nodes),
176 }
177 }
178 CoreFrame::PrimOp { op, args } => {
179 let mut ai = Vec::new();
180 let mut av = Vec::new();
181 for &a in args {
182 let (i, v) = partial_eval_at(expr, a, env, new_nodes);
183 ai.push(i);
184 av.push(v);
185 }
186 if let Some(result) = try_eval_primop(*op, &av) {
187 let ni = new_nodes.len();
188 new_nodes.push(CoreFrame::Lit(result.clone()));
189 (ni, PartialValue::Known(KnownValue::Lit(result)))
190 } else {
191 let ni = new_nodes.len();
192 new_nodes.push(CoreFrame::PrimOp { op: *op, args: ai });
193 (ni, PartialValue::Unknown)
194 }
195 }
196 CoreFrame::App { fun, arg } => {
197 let (fi, _) = partial_eval_at(expr, *fun, env, new_nodes);
198 let (ai, _) = partial_eval_at(expr, *arg, env, new_nodes);
199 let ni = new_nodes.len();
200 new_nodes.push(CoreFrame::App { fun: fi, arg: ai });
201 (ni, PartialValue::Unknown)
202 }
203 CoreFrame::Lam { binder, body } => {
204 let (bi, _) = partial_eval_at(expr, *body, env, new_nodes);
205 let ni = new_nodes.len();
206 new_nodes.push(CoreFrame::Lam {
207 binder: *binder,
208 body: bi,
209 });
210 (ni, PartialValue::Unknown)
211 }
212 CoreFrame::Join {
213 label,
214 params,
215 rhs,
216 body,
217 } => {
218 let (ri, _) = partial_eval_at(expr, *rhs, env, new_nodes);
219 let (bi, bv) = partial_eval_at(expr, *body, env, new_nodes);
220 let ni = new_nodes.len();
221 new_nodes.push(CoreFrame::Join {
222 label: *label,
223 params: params.clone(),
224 rhs: ri,
225 body: bi,
226 });
227 (ni, bv)
228 }
229 CoreFrame::Jump { label, args } => {
230 let mut ai = Vec::new();
231 for &a in args {
232 let (i, _) = partial_eval_at(expr, a, env, new_nodes);
233 ai.push(i);
234 }
235 let ni = new_nodes.len();
236 new_nodes.push(CoreFrame::Jump {
237 label: *label,
238 args: ai,
239 });
240 (ni, PartialValue::Unknown)
241 }
242 }
243}
244
245fn emit_known(kv: &KnownValue, new_nodes: &mut Vec<CoreFrame<usize>>) -> usize {
247 match kv {
248 KnownValue::Lit(lit) => {
249 let ni = new_nodes.len();
250 new_nodes.push(CoreFrame::Lit(lit.clone()));
251 ni
252 }
253 KnownValue::Con(tag, fields) => {
254 let fi: Vec<usize> = fields.iter().map(|k| emit_known(k, new_nodes)).collect();
255 let ni = new_nodes.len();
256 new_nodes.push(CoreFrame::Con {
257 tag: *tag,
258 fields: fi,
259 });
260 ni
261 }
262 }
263}
264
265fn emit_residual_case(
267 expr: &CoreExpr,
268 scrut_idx: usize,
269 binder: &VarId,
270 alts: &[Alt<usize>],
271 env: &PartialEnv,
272 new_nodes: &mut Vec<CoreFrame<usize>>,
273) -> (usize, PartialValue) {
274 let mut new_env = env.clone();
275 new_env.insert(*binder, PartialValue::Unknown);
276 let mut new_alts = Vec::new();
277 for alt in alts {
278 let mut alt_env = new_env.clone();
279 for b in &alt.binders {
280 alt_env.insert(*b, PartialValue::Unknown);
281 }
282 let (bi, _) = partial_eval_at(expr, alt.body, &alt_env, new_nodes);
283 new_alts.push(Alt {
284 con: alt.con.clone(),
285 binders: alt.binders.clone(),
286 body: bi,
287 });
288 }
289 let ni = new_nodes.len();
290 new_nodes.push(CoreFrame::Case {
291 scrutinee: scrut_idx,
292 binder: *binder,
293 alts: new_alts,
294 });
295 (ni, PartialValue::Unknown)
296}
297
298fn try_eval_primop(op: PrimOpKind, args: &[PartialValue]) -> Option<Literal> {
300 let lits: Vec<&Literal> = args
301 .iter()
302 .filter_map(|a| match a {
303 PartialValue::Known(KnownValue::Lit(l)) => Some(l),
304 _ => None,
305 })
306 .collect();
307 if lits.len() != args.len() {
308 return None;
309 }
310 match op {
311 PrimOpKind::IntAdd => {
312 if let [Literal::LitInt(a), Literal::LitInt(b)] = &lits[..] {
313 Some(Literal::LitInt(a.wrapping_add(*b)))
314 } else {
315 None
316 }
317 }
318 PrimOpKind::IntSub => {
319 if let [Literal::LitInt(a), Literal::LitInt(b)] = &lits[..] {
320 Some(Literal::LitInt(a.wrapping_sub(*b)))
321 } else {
322 None
323 }
324 }
325 PrimOpKind::IntMul => {
326 if let [Literal::LitInt(a), Literal::LitInt(b)] = &lits[..] {
327 Some(Literal::LitInt(a.wrapping_mul(*b)))
328 } else {
329 None
330 }
331 }
332 PrimOpKind::IntNegate => {
333 if let [Literal::LitInt(a)] = &lits[..] {
334 Some(Literal::LitInt(a.wrapping_neg()))
335 } else {
336 None
337 }
338 }
339 PrimOpKind::IntEq => int_cmp(&lits, |a, b| a == b),
340 PrimOpKind::IntNe => int_cmp(&lits, |a, b| a != b),
341 PrimOpKind::IntLt => int_cmp(&lits, |a, b| a < b),
342 PrimOpKind::IntLe => int_cmp(&lits, |a, b| a <= b),
343 PrimOpKind::IntGt => int_cmp(&lits, |a, b| a > b),
344 PrimOpKind::IntGe => int_cmp(&lits, |a, b| a >= b),
345 _ => None,
346 }
347}
348
349fn int_cmp(lits: &[&Literal], f: impl Fn(i64, i64) -> bool) -> Option<Literal> {
351 if let [Literal::LitInt(a), Literal::LitInt(b)] = lits {
352 Some(Literal::LitInt(if f(*a, *b) { 1 } else { 0 }))
353 } else {
354 None
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::env::Env;
    use tidepool_eval::eval;
    use tidepool_eval::heap::VecHeap;
    use tidepool_eval::value::Value;
    use tidepool_repr::{Alt, AltCon, CoreFrame, DataConId, Literal, PrimOpKind, VarId};

    /// `let x1 = 1 in let x2 = 2 in x1 + x2` folds to `3`.
    #[test]
    fn test_partial_all_known() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            CoreFrame::Var(VarId(1)),
            CoreFrame::Var(VarId(2)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![2, 3],
            },
            CoreFrame::LetNonRec {
                binder: VarId(2),
                rhs: 1,
                body: 4,
            },
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 5,
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(3)));
    }

    /// A lone free variable is left untouched.
    #[test]
    fn test_partial_all_unknown() {
        let nodes = vec![CoreFrame::Var(VarId(1))];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        let changed = pass.run(&mut expr);

        assert!(!changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Var(VarId(1)));
    }

    /// A case on a known constructor selects its branch and binds the field.
    #[test]
    fn test_partial_case_known_con() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            },
            CoreFrame::Var(VarId(2)),
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(3),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(2)],
                    body: 2,
                }],
            },
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 1,
                body: 3,
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    /// A case on an unknown scrutinee must survive as a residual case.
    #[test]
    fn test_partial_unknown_scrutinee() {
        let nodes = vec![
            CoreFrame::Var(VarId(1)),
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        // Checked unconditionally: whether or not the pass rewrote anything,
        // the root (last node) must still be a `case`.
        assert!(matches!(expr.nodes.last().unwrap(), CoreFrame::Case { .. }));
    }

    /// A case on a known literal selects the matching `LitAlt` over `Default`.
    #[test]
    fn test_partial_case_known_lit() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(10)),
            CoreFrame::Lit(Literal::LitInt(20)),
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(1),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(10)));
    }

    /// A known constructor with no matching `DataAlt` falls back to `Default`.
    #[test]
    fn test_partial_case_default_fallback() {
        let nodes = vec![
            CoreFrame::Con {
                tag: DataConId(2),
                fields: vec![],
            },
            CoreFrame::Lit(Literal::LitInt(10)),
            CoreFrame::Lit(Literal::LitInt(20)),
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(1),
                alts: vec![
                    Alt {
                        con: AltCon::DataAlt(DataConId(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(20)));
    }

    /// `1 + 2` folds to `3`.
    #[test]
    fn test_partial_primop_fold() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(3)));
    }

    /// Integer comparisons fold to `LitInt(1)` / `LitInt(0)`.
    #[test]
    fn test_partial_comparison_fold() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),
            CoreFrame::Lit(Literal::LitInt(5)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntLt,
                args: vec![0, 1],
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(1)));
    }

    /// A primop with an unknown argument is kept as a residual primop.
    #[test]
    fn test_partial_primop_unknown_arg() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(VarId(1)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert!(matches!(
            expr.nodes.last().unwrap(),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                ..
            }
        ));
    }

    /// The pass does not change the value produced by the evaluator.
    #[test]
    fn test_partial_preserves_eval() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(10)),
            CoreFrame::Lit(Literal::LitInt(20)),
            CoreFrame::Var(VarId(1)),
            CoreFrame::Var(VarId(2)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![2, 3],
            },
            CoreFrame::LetNonRec {
                binder: VarId(2),
                rhs: 1,
                body: 4,
            },
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 5,
            },
        ];
        let mut expr = CoreExpr { nodes };

        let mut heap_before = VecHeap::new();
        let val_before = eval(&expr, &Env::new(), &mut heap_before).unwrap();

        let pass = PartialEval;
        pass.run(&mut expr);

        let mut heap_after = VecHeap::new();
        let val_after = eval(&expr, &Env::new(), &mut heap_after).unwrap();

        if let (Value::Lit(Literal::LitInt(n1)), Value::Lit(Literal::LitInt(n2))) =
            (val_before, val_after)
        {
            assert_eq!(n1, 30);
            assert_eq!(n2, 30);
        } else {
            panic!("Expected LitInt(30)");
        }
    }

    /// `let x1 = 1 in let x2 = x1 + 2 in x2 + 3` folds to `6`.
    #[test]
    fn test_partial_nested_let() {
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Var(VarId(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::Var(VarId(2)),
            CoreFrame::Lit(Literal::LitInt(3)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![4, 5],
            },
            CoreFrame::LetNonRec {
                binder: VarId(2),
                rhs: 3,
                body: 6,
            },
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 7,
            },
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(6)));
    }
}