vyre_reference/execution/
mod.rs1pub(crate) mod call;
8pub mod expr;
9pub(crate) mod expr_cast;
10pub(crate) mod hashmap;
11pub mod node;
12pub mod sequential;
13pub(crate) mod typed_ops;
14
15use std::borrow::Cow;
16
17use rustc_hash::FxHashMap;
18use vyre::ir::{InterpCtx, Node, NodeId, NodeStorage, Program, Value as IrValue};
19
20use crate::value::Value;
21
22pub(crate) fn program_for_interpreter(program: &Program) -> Result<Cow<'_, Program>, vyre::Error> {
30 if let Some(message) = program.top_level_region_violation() {
31 if program.entry().is_empty() {
32 return Err(vyre::Error::interp(format!(
33 "reference interpreter requires a top-level Region-wrapped Program: {message}"
34 )));
35 }
36 if matches!(program.entry().first(), Some(Node::Store { .. })) {
37 return Err(vyre::Error::interp(format!(
38 "reference interpreter requires a top-level Region-wrapped Program: {message}"
39 )));
40 }
41 return Ok(Cow::Owned(program.clone().reconcile_runnable_top_level()));
42 }
43 Ok(Cow::Borrowed(program))
44}
45
46pub fn reference_eval(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
52 run_arena_reference(program, inputs)
53}
54
55pub fn run_arena_reference(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
57 let program = program_for_interpreter(program)?;
58 hashmap::run_hashmap_reference(&program, inputs)
59}
60
/// Test-only entry point that forwards to `run_arena_reference`, which executes the
/// hashmap-backed reference interpreter.
///
/// # Errors
///
/// Propagates any error from `run_arena_reference`.
#[cfg(test)]
pub fn eval_hashmap_reference(
    program: &Program,
    inputs: &[Value],
) -> Result<Vec<Value>, vyre::Error> {
    run_arena_reference(program, inputs)
}
69
70pub fn run_storage_graph(
72 nodes: &[(NodeId, NodeStorage)],
73 outputs: &[NodeId],
74) -> Result<Vec<IrValue>, vyre::Error> {
75 let graph = nodes
76 .iter()
77 .map(|(id, node)| (*id, node))
78 .collect::<FxHashMap<_, _>>();
79 let mut ctx = InterpCtx::default();
80 let mut states = FxHashMap::with_capacity_and_hasher(graph.len(), Default::default());
81
82 for output in outputs {
83 eval_storage_node(*output, &graph, &mut ctx, &mut states)?;
84 }
85
86 outputs
87 .iter()
88 .map(|id| ctx.get(*id).map_err(interp_error))
89 .collect()
90}
91
/// Per-node progress marker for the depth-first evaluation in `eval_storage_node`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum VisitState {
    /// Entered but not yet interpreted; reaching it again signals a dependency cycle.
    Visiting,
    /// Interpreted, with its value already stored in the interpreter context.
    Done,
}
97
98fn eval_storage_node(
99 id: NodeId,
100 graph: &FxHashMap<NodeId, &NodeStorage>,
101 ctx: &mut InterpCtx,
102 states: &mut FxHashMap<NodeId, VisitState>,
103) -> Result<(), vyre::Error> {
104 match states.get(&id).copied() {
105 Some(VisitState::Done) => return Ok(()),
106 Some(VisitState::Visiting) => return Err(cycle_error(id)),
107 None => {}
108 }
109
110 let node = *graph.get(&id).ok_or_else(|| missing_node_error(id))?;
111 states.insert(id, VisitState::Visiting);
112 let inputs = node.input_ids();
113 for input in &inputs {
114 eval_storage_node(*input, graph, ctx, states)?;
115 }
116 ctx.set_operands(inputs);
117 let value = node.interpret(ctx).map_err(interp_error)?;
118 ctx.set(id, value);
119 states.insert(id, VisitState::Done);
120 Ok(())
121}
122
123fn interp_error(error: vyre::ir::EvalError) -> vyre::Error {
124 vyre::Error::interp(error.to_string())
125}
126
127fn missing_node_error(id: NodeId) -> vyre::Error {
128 vyre::Error::interp(format!(
129 "graph references missing node {}. Fix: include every dependency in the interpreter input graph.",
130 id.0
131 ))
132}
133
134fn cycle_error(id: NodeId) -> vyre::Error {
135 vyre::Error::interp(format!(
136 "graph contains a dependency cycle at node {}. Fix: submit an acyclic dataflow graph.",
137 id.0
138 ))
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use vyre::ir::{BinOp, NodeStorage};

    /// Number of random graphs compared against the recursive oracle.
    const CASES: u32 = 10_000;

    /// Multiplier of the 64-bit linear congruential generator driven by `next`.
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;

    /// Cross-checks `run_storage_graph` against a naive recursive evaluator on
    /// pseudo-random graphs built from a fixed seed.
    #[test]
    fn generic_storage_graph_matches_recursive_oracle_for_10k_programs() {
        let mut seed = 0x9e37_79b9_u64;
        for case in 0..CASES {
            let graph = random_graph(&mut seed, case);
            let root = graph
                .last()
                .map(|(id, _)| *id)
                .expect("Fix: generated graph is non-empty");
            let expected = recursive_value(root, &graph).expect("Fix: recursive oracle evaluates");
            let values = run_storage_graph(&graph, &[root])
                .expect("Fix: generic graph interpreter evaluates");
            assert_eq!(values[0], expected, "case {case}");
        }
    }

    /// Builds a random acyclic graph with 2 to 32 nodes. Ids 0 and 1 are `LitU32`
    /// leaves; every later id is a `BinOp` whose operands point at strictly smaller ids.
    fn random_graph(rng: &mut u64, case: u32) -> Vec<(NodeId, NodeStorage)> {
        let len = 2 + next(rng) % 31;
        let mut graph = vec![
            (NodeId(0), NodeStorage::LitU32(case)),
            (NodeId(1), NodeStorage::LitU32(next(rng))),
        ];
        graph.reserve((len - 2) as usize);
        for index in 2..len {
            let left = NodeId(next(rng) % index);
            let right = NodeId(next(rng) % index);
            let op = random_op(next(rng));
            graph.push((NodeId(index), NodeStorage::BinOp { op, left, right }));
        }
        graph
    }

    /// Maps a random draw onto one of the five operators the oracle understands.
    fn random_op(draw: u32) -> BinOp {
        match draw % 5 {
            0 => BinOp::Add,
            1 => BinOp::Sub,
            2 => BinOp::Mul,
            3 => BinOp::BitXor,
            _ => BinOp::BitAnd,
        }
    }

    /// Oracle: evaluates `id` by plain recursion with no memoization. Only `LitU32`
    /// leaves and the operators produced by `random_op` are supported.
    fn recursive_value(
        id: NodeId,
        graph: &[(NodeId, NodeStorage)],
    ) -> Result<IrValue, vyre::Error> {
        let (_, node) = graph
            .iter()
            .find(|(candidate, _)| *candidate == id)
            .ok_or_else(|| missing_node_error(id))?;

        let (op, left, right) = match node {
            NodeStorage::LitU32(value) => return Ok(IrValue::U32(*value)),
            NodeStorage::BinOp { op, left, right } => (op, *left, *right),
            _ => {
                return Err(vyre::Error::interp(
                    "recursive parity oracle received unsupported node. Fix: keep test generation within the oracle domain.",
                ));
            }
        };

        // Operands are evaluated (left, then right) before the operator is checked.
        let lhs = expect_u32(recursive_value(left, graph)?)?;
        let rhs = expect_u32(recursive_value(right, graph)?)?;
        let folded = match op {
            BinOp::Add => lhs.wrapping_add(rhs),
            BinOp::Sub => lhs.wrapping_sub(rhs),
            BinOp::Mul => lhs.wrapping_mul(rhs),
            BinOp::BitXor => lhs ^ rhs,
            BinOp::BitAnd => lhs & rhs,
            _ => {
                return Err(vyre::Error::interp(
                    "recursive parity oracle received unsupported op. Fix: keep test generation within the oracle domain.",
                ));
            }
        };
        Ok(IrValue::U32(folded))
    }

    /// Unwraps a `U32` value, or reports any other variant as an oracle error.
    fn expect_u32(value: IrValue) -> Result<u32, vyre::Error> {
        match value {
            IrValue::U32(raw) => Ok(raw),
            other => Err(vyre::Error::interp(format!(
                "recursive parity oracle expected u32, got {other:?}. Fix: keep generated graphs scalar-u32 only."
            ))),
        }
    }

    /// Advances the LCG state in `rng` and returns the high 32 bits of the new state.
    fn next(rng: &mut u64) -> u32 {
        let advanced = rng.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1);
        *rng = advanced;
        (advanced >> 32) as u32
    }
}