1use std::fmt::Debug;
2use std::marker::PhantomData;
3
4use trellis_core::{
5 Graph, GraphResult, InputNode, InvariantResultTrace, OutputFrameTrace, ResourceCommandTrace,
6 Transaction, TransactionResult,
7};
8
9use crate::{
10 FullRecomputeOracle, OracleCheck, OracleMismatch, OutputLedger, ResourceLedger, Scenario,
11 ScenarioError, StageOperation, TransactionScript,
12};
13
14type InvariantCheck<G, C, O> = dyn Fn(&G, &TransactionResult<C, O>) -> bool + 'static;
15
/// A value the harness can drive: anything that exposes a [`Graph`] both immutably and mutably.
///
/// The harness reads the graph through [`ScenarioTarget::graph`] (e.g. for the final-state debug
/// dump) and opens/commits transactions through [`ScenarioTarget::graph_mut`].
pub trait ScenarioTarget<C = (), O = ()> {
    /// Shared access to the underlying graph.
    fn graph(&self) -> &Graph<C, O>;

    /// Exclusive access to the underlying graph; the harness begins transactions on it.
    fn graph_mut(&mut self) -> &mut Graph<C, O>;
}
24
25impl<C, O> ScenarioTarget<C, O> for Graph<C, O> {
26 fn graph(&self) -> &Graph<C, O> {
27 self
28 }
29
30 fn graph_mut(&mut self) -> &mut Graph<C, O> {
31 self
32 }
33}
34
/// Test harness that drives a target graph through named transaction steps.
///
/// Each committed step records a trace in a [`Scenario`] and feeds the committed
/// `TransactionResult` into a resource ledger and an output ledger.
pub struct TrellisHarness<G, C = (), O = ()> {
    /// The system under test; its graph is reached through [`ScenarioTarget`].
    target: G,
    /// Per-step traces recorded as steps commit.
    scenario: Scenario,
    /// Ledger updated from every committed transaction result.
    resource_ledger: ResourceLedger<C>,
    /// Ledger updated from every committed transaction result.
    output_ledger: OutputLedger<O>,
    // NOTE(review): `C` (and `O`) are already used by the ledger fields above, so this marker
    // looks redundant — confirm nothing depends on it before removing.
    _marker: PhantomData<fn() -> C>,
}
43
44impl<G, C, O> TrellisHarness<G, C, O>
45where
46 G: ScenarioTarget<C, O>,
47 C: Clone + Debug + PartialEq,
48 O: Clone + Debug + PartialEq,
49{
50 pub fn new(build: impl FnOnce() -> G) -> Self {
52 Self::from_target(build())
53 }
54
55 pub fn from_target(target: G) -> Self {
57 Self {
58 target,
59 scenario: Scenario::new(),
60 resource_ledger: ResourceLedger::new(),
61 output_ledger: OutputLedger::new(),
62 _marker: PhantomData,
63 }
64 }
65
66 pub fn target(&self) -> &G {
68 &self.target
69 }
70
71 pub fn scenario(&self) -> &Scenario {
73 &self.scenario
74 }
75
76 pub fn resource_ledger(&self) -> &ResourceLedger<C> {
78 &self.resource_ledger
79 }
80
81 pub fn output_ledger(&self) -> &OutputLedger<O> {
83 &self.output_ledger
84 }
85
86 pub fn step(&mut self, name: impl Into<String>) -> HarnessStep<'_, G, C, O> {
88 HarnessStep {
89 harness: self,
90 name: name.into(),
91 operations: Vec::new(),
92 expected_resource_commands: None,
93 expected_output_frames: None,
94 invariant_checks: Vec::new(),
95 }
96 }
97
98 pub fn run_script(&mut self, script: &TransactionScript<C, O>) -> Result<(), ScenarioError> {
100 for step in script.steps() {
101 self.commit_operations(step.name(), &step.operations, &[], None, None)?;
102 }
103 Ok(())
104 }
105
106 pub fn replay(
108 build: impl FnOnce() -> G,
109 script: &TransactionScript<C, O>,
110 ) -> Result<Self, ScenarioError> {
111 let mut harness = Self::new(build);
112 harness.run_script(script)?;
113 Ok(harness)
114 }
115
116 pub fn assert_replay_matches(&self, other: &Self) -> Result<(), ScenarioError> {
118 self.scenario.assert_replay_matches(&other.scenario)?;
119 let expected = self.final_state_debug_dump();
120 let actual = other.final_state_debug_dump();
121 if expected != actual {
122 return Err(ScenarioError::ReplayFinalStateMismatch { expected, actual });
123 }
124 assert_equal_debug(
125 "resource_command_records",
126 self.resource_ledger.command_records(),
127 other.resource_ledger.command_records(),
128 )?;
129 assert_equal_debug(
130 "output_frame_records",
131 self.output_ledger.frame_records(),
132 other.output_ledger.frame_records(),
133 )?;
134 assert_equal_debug(
135 "resource_ledger_snapshots",
136 &self.resource_ledger,
137 &other.resource_ledger,
138 )?;
139 assert_equal_debug(
140 "output_ledger_snapshots",
141 &self.output_ledger,
142 &other.output_ledger,
143 )?;
144 Ok(())
145 }
146
147 pub fn final_state_debug_dump(&self) -> String {
149 self.target.graph().debug_dump()
150 }
151
152 pub fn assert_oracle<Oracle>(
154 &self,
155 inputs: &Oracle::CanonicalInputs,
156 ) -> Result<OracleCheck<Oracle::ExpectedState>, OracleMismatch<Oracle::ExpectedState>>
157 where
158 Oracle: FullRecomputeOracle<G>,
159 {
160 crate::assert_incremental_equals_full::<G, Oracle>(&self.target, inputs)
161 }
162
163 fn commit_operations(
164 &mut self,
165 name: &str,
166 operations: &[Box<StageOperation<C, O>>],
167 invariant_checks: &[NamedInvariantCheck<G, C, O>],
168 expected_resource_commands: Option<&[ResourceCommandTrace]>,
169 expected_output_frames: Option<&[OutputFrameTrace]>,
170 ) -> Result<(), ScenarioError> {
171 let result = {
172 let graph = self.target.graph_mut();
173 let mut tx = graph
174 .begin_transaction()
175 .map_err(|error| step_commit_failed(name, error))?;
176 for operation in operations {
177 operation(&mut tx).map_err(|error| step_commit_failed(name, error))?;
178 }
179 tx.commit()
180 .map_err(|error| step_commit_failed(name, error))?
181 };
182
183 let mut trace = result.trace();
184 for check in invariant_checks {
185 let passed = (check.check)(&self.target, &result);
186 trace.invariant_results.push(InvariantResultTrace {
187 name: check.name.clone(),
188 passed,
189 });
190 if !passed {
191 return Err(ScenarioError::InvariantFailed {
192 step: name.to_owned(),
193 invariant: check.name.clone(),
194 transaction_id: result.transaction_id,
195 revision: result.revision,
196 });
197 }
198 }
199
200 self.resource_ledger.apply_result(&result);
201 self.output_ledger.apply_result(&result);
202 self.scenario.record_trace(name, trace);
203
204 if let Some(expected) = expected_resource_commands {
205 self.scenario
206 .assert_step_resource_commands(name, expected)?;
207 }
208 if let Some(expected) = expected_output_frames {
209 self.scenario.assert_step_output_frames(name, expected)?;
210 }
211 Ok(())
212 }
213}
214
215pub struct HarnessStep<'harness, G, C, O> {
217 harness: &'harness mut TrellisHarness<G, C, O>,
218 name: String,
219 operations: Vec<Box<StageOperation<C, O>>>,
220 expected_resource_commands: Option<Vec<ResourceCommandTrace>>,
221 expected_output_frames: Option<Vec<OutputFrameTrace>>,
222 invariant_checks: Vec<NamedInvariantCheck<G, C, O>>,
223}
224
225impl<'harness, G, C, O> HarnessStep<'harness, G, C, O>
226where
227 G: ScenarioTarget<C, O>,
228 C: Clone + Debug + PartialEq,
229 O: Clone + Debug + PartialEq,
230{
231 pub fn input<T>(mut self, input: InputNode<T>, value: T) -> Self
233 where
234 T: Clone + PartialEq + 'static,
235 {
236 self.operations
237 .push(Box::new(move |tx| tx.set_input(input, value.clone())));
238 self
239 }
240
241 pub fn operation(
243 mut self,
244 operation: impl for<'tx> Fn(&mut Transaction<'tx, C, O>) -> GraphResult<()> + 'static,
245 ) -> Self {
246 self.operations.push(Box::new(operation));
247 self
248 }
249
250 pub fn expect_plan(mut self, command: ResourceCommandTrace) -> Self {
252 self.expected_resource_commands
253 .get_or_insert_with(Vec::new)
254 .push(command);
255 self
256 }
257
258 pub fn expect_plans(
260 mut self,
261 commands: impl IntoIterator<Item = ResourceCommandTrace>,
262 ) -> Self {
263 self.expected_resource_commands = Some(commands.into_iter().collect());
264 self
265 }
266
267 pub fn expect_output(mut self, frame: OutputFrameTrace) -> Self {
269 self.expected_output_frames
270 .get_or_insert_with(Vec::new)
271 .push(frame);
272 self
273 }
274
275 pub fn expect_outputs(mut self, frames: impl IntoIterator<Item = OutputFrameTrace>) -> Self {
277 self.expected_output_frames = Some(frames.into_iter().collect());
278 self
279 }
280
281 pub fn check(
283 mut self,
284 name: impl Into<String>,
285 check: impl Fn(&G, &TransactionResult<C, O>) -> bool + 'static,
286 ) -> Self {
287 self.invariant_checks.push(NamedInvariantCheck {
288 name: name.into(),
289 check: Box::new(check),
290 });
291 self
292 }
293
294 pub fn commit(self) -> Result<&'harness mut TrellisHarness<G, C, O>, ScenarioError> {
296 self.harness.commit_operations(
297 &self.name,
298 &self.operations,
299 &self.invariant_checks,
300 self.expected_resource_commands.as_deref(),
301 self.expected_output_frames.as_deref(),
302 )?;
303 Ok(self.harness)
304 }
305}
306
/// An invariant registered via [`HarnessStep::check`], evaluated after a step's transaction commits.
struct NamedInvariantCheck<G, C, O> {
    /// Label reported in the step trace and in [`ScenarioError::InvariantFailed`].
    name: String,
    /// Predicate over the target and the committed result; `false` means the invariant failed.
    check: Box<InvariantCheck<G, C, O>>,
}
311
312fn step_commit_failed(step: &str, error: trellis_core::GraphError) -> ScenarioError {
313 ScenarioError::StepCommitFailed {
314 step: step.to_owned(),
315 error,
316 }
317}
318
319fn assert_equal_debug<T>(field: &'static str, expected: &T, actual: &T) -> Result<(), ScenarioError>
320where
321 T: Debug + PartialEq + ?Sized,
322{
323 if expected == actual {
324 Ok(())
325 } else {
326 Err(ScenarioError::ReplayLedgerMismatch {
327 field,
328 expected: format!("{expected:#?}"),
329 actual: format!("{actual:#?}"),
330 })
331 }
332}