1use std::ops::AddAssign;
2
3use air::AIR;
4use air::table::TableId;
5use air::table::hash::HashTable;
6use air::table::hash::PermutationTrace;
7use air::table::op_stack;
8use air::table::processor;
9use air::table::ram;
10use air::table_column::HashMainColumn::CI;
11use air::table_column::MasterMainColumn;
12use arbitrary::Arbitrary;
13use indexmap::IndexMap;
14use indexmap::map::Entry::Occupied;
15use indexmap::map::Entry::Vacant;
16use isa::error::InstructionError;
17use isa::error::InstructionError::InstructionPointerOverflow;
18use isa::instruction::Instruction;
19use isa::program::Program;
20use itertools::Itertools;
21use ndarray::Array2;
22use ndarray::s;
23use strum::EnumCount;
24use strum::IntoEnumIterator;
25use twenty_first::prelude::*;
26
27use crate::ndarray_helper::ROW_AXIS;
28use crate::table;
29use crate::table::op_stack::OpStackTableEntry;
30use crate::table::ram::RamTableCall;
31use crate::table::u32::U32TableEntry;
32use crate::vm::CoProcessorCall;
33use crate::vm::VMState;
34
#[derive(Debug, Clone)]
pub struct AlgebraicExecutionTrace {
    /// The program that ran (or is running) to produce this trace.
    pub program: Program,

    /// How often each instruction in [`program`](Self::program) was executed.
    /// One counter per b-field-word of the program; indexed by instruction
    /// pointer.
    pub instruction_multiplicities: Vec<u32>,

    /// One row per executed clock cycle, recording the full processor state.
    pub processor_trace: Array2<BFieldElement>,

    /// Rows recording op-stack underflow memory accesses.
    pub op_stack_underflow_trace: Array2<BFieldElement>,

    /// Rows recording RAM accesses.
    pub ram_trace: Array2<BFieldElement>,

    /// The Tip5 permutation trace produced by hashing the (padded) program
    /// itself. Filled once, at construction time.
    pub program_hash_trace: Array2<BFieldElement>,

    /// Permutation traces caused by the `hash` instruction.
    pub hash_trace: Array2<BFieldElement>,

    /// Permutation traces caused by the Sponge instructions
    /// (`sponge_init`, `sponge_absorb`, `sponge_squeeze`).
    pub sponge_trace: Array2<BFieldElement>,

    /// Each distinct U32 Table entry, mapped to the number of times it was
    /// looked up. Insertion order is preserved (`IndexMap`).
    pub u32_entries: IndexMap<U32TableEntry, u64>,

    /// For each 16-bit limb that was split-and-looked-up during a Tip5
    /// permutation: how often the Cascade Table was consulted for it.
    pub cascade_table_lookup_multiplicities: IndexMap<u16, u64>,

    /// For each possible 8-bit value: how often the Lookup Table was
    /// consulted for it. Fixed size — one slot per 8-bit value.
    pub lookup_table_lookup_multiplicities: [u64; AlgebraicExecutionTrace::LOOKUP_TABLE_HEIGHT],
}
93
/// A table's identifier paired with that table's (unpadded) height.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct TableHeight {
    /// Which table this height belongs to.
    pub table: TableId,
    /// The number of rows in that table.
    pub height: usize,
}
99
impl AlgebraicExecutionTrace {
    /// Height of the Lookup Table: one row per possible 8-bit limb value.
    pub(crate) const LOOKUP_TABLE_HEIGHT: usize = 1 << 8;

    /// Create an empty trace for the given program.
    ///
    /// All execution traces start with zero rows; rows are appended as the VM
    /// runs. The program-hash trace is the exception: it is computed and
    /// filled in right here, since it depends only on the program.
    pub fn new(program: Program) -> Self {
        // Table widths are dictated by the corresponding AIR definitions.
        const PROCESSOR_WIDTH: usize = <processor::ProcessorTable as AIR>::MainColumn::COUNT;
        const OP_STACK_WIDTH: usize = <op_stack::OpStackTable as AIR>::MainColumn::COUNT;
        const RAM_WIDTH: usize = <ram::RamTable as AIR>::MainColumn::COUNT;
        const HASH_WIDTH: usize = <HashTable as AIR>::MainColumn::COUNT;

        let program_len = program.len_bwords();

        let mut aet = Self {
            program,
            // One multiplicity counter per b-field word of the program.
            instruction_multiplicities: vec![0_u32; program_len],
            processor_trace: Array2::default([0, PROCESSOR_WIDTH]),
            op_stack_underflow_trace: Array2::default([0, OP_STACK_WIDTH]),
            ram_trace: Array2::default([0, RAM_WIDTH]),
            program_hash_trace: Array2::default([0, HASH_WIDTH]),
            hash_trace: Array2::default([0, HASH_WIDTH]),
            sponge_trace: Array2::default([0, HASH_WIDTH]),
            u32_entries: IndexMap::new(),
            cascade_table_lookup_multiplicities: IndexMap::new(),
            lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT],
        };
        aet.fill_program_hash_trace();
        aet
    }

    /// The height of the tallest table, rounded up to the next power of two.
    pub fn padded_height(&self) -> usize {
        self.height().height.next_power_of_two()
    }

    /// The tallest table, together with its height.
    ///
    /// `Iterator::max` cannot fail here: `TableId::iter()` is non-empty.
    pub fn height(&self) -> TableHeight {
        TableId::iter()
            .map(|t| TableHeight::new(t, self.height_of_table(t)))
            .max()
            .unwrap()
    }

    /// The current, unpadded height of the given table.
    pub fn height_of_table(&self, table: TableId) -> usize {
        // The Hash Table aggregates all three hash-related traces.
        let hash_table_height = || {
            self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows()
        };

        match table {
            TableId::Program => Self::padded_program_length(&self.program),
            TableId::Processor => self.processor_trace.nrows(),
            TableId::OpStack => self.op_stack_underflow_trace.nrows(),
            TableId::Ram => self.ram_trace.nrows(),
            // The Jump Stack Table has one row per clock cycle, i.e., the
            // same height as the processor trace.
            TableId::JumpStack => self.processor_trace.nrows(),
            TableId::Hash => hash_table_height(),
            // One Cascade Table row per distinct 16-bit limb looked up.
            TableId::Cascade => self.cascade_table_lookup_multiplicities.len(),
            TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT,
            TableId::U32 => self.u32_table_height(),
        }
    }

    /// Sum of every U32 Table entry's individual height contribution.
    fn u32_table_height(&self) -> usize {
        let entry_len = U32TableEntry::table_height_contribution;
        // Summed in u32, then converted; panics only if the sum exceeds
        // usize::MAX, which cannot happen on ≥32-bit targets.
        let height = self.u32_entries.keys().map(entry_len).sum::<u32>();
        height.try_into().unwrap()
    }

    /// Program length after hash-input padding: one mandatory padding element
    /// (the 1) is always appended, then the result is rounded up to the next
    /// multiple of the sponge's rate.
    fn padded_program_length(program: &Program) -> usize {
        (program.len_bwords() + 1).next_multiple_of(Tip5::RATE)
    }

    /// Hash the padded program and record the resulting permutation traces in
    /// [`Self::program_hash_trace`].
    ///
    /// Also bumps the cascade/lookup multiplicities for every permutation,
    /// and sanity-checks the resulting digest against `Program::hash`.
    fn fill_program_hash_trace(&mut self) {
        let padded_program = Self::hash_input_pad_program(&self.program);
        let mut program_sponge = Tip5::init();
        for chunk_to_absorb in padded_program.chunks(Tip5::RATE) {
            // Absorb by overwriting the first RATE state elements with the
            // current chunk.
            program_sponge.state[..Tip5::RATE]
                .iter_mut()
                .zip_eq(chunk_to_absorb)
                .for_each(|(sponge_state_elem, &absorb_elem)| *sponge_state_elem = absorb_elem);
            let hash_trace = program_sponge.trace();
            let trace_addendum = table::hash::trace_to_table_rows(hash_trace);

            self.increase_lookup_multiplicities(hash_trace);
            self.program_hash_trace
                .append(ROW_AXIS, trace_addendum.view())
                .expect("shapes must be identical");
        }

        // Every row of the program-hash trace carries the `hash` opcode in
        // the current-instruction column.
        let instruction_column_index = CI.main_index();
        let mut instruction_column = self.program_hash_trace.column_mut(instruction_column_index);
        instruction_column.fill(Instruction::Hash.opcode_b());

        // Sanity check: the digest produced by the traced permutations must
        // equal the program's canonical hash.
        let program_digest = program_sponge.state[..Digest::LEN].try_into().unwrap();
        let program_digest = Digest::new(program_digest);
        let expected_digest = self.program.hash();
        assert_eq!(expected_digest, program_digest);
    }

    /// The program as b-field words, followed by a single 1, followed by
    /// zeros up to [`Self::padded_program_length`].
    ///
    /// The `zeros` chain provides at most one rate's worth of zeros; `take`
    /// truncates to the exact padded length, which never exceeds
    /// `len + 1 + (RATE - 1)`.
    fn hash_input_pad_program(program: &Program) -> Vec<BFieldElement> {
        let padded_program_length = Self::padded_program_length(program);

        let program_iter = program.to_bwords().into_iter();
        let one = bfe_array![1];
        let zeros = bfe_array![0; tip5::RATE];
        program_iter
            .chain(one)
            .chain(zeros)
            .take(padded_program_length)
            .collect()
    }

    /// Record one VM state: bump the executed instruction's multiplicity and
    /// append the state as a new processor-trace row.
    ///
    /// # Errors
    ///
    /// Returns [`InstructionPointerOverflow`] if the instruction pointer
    /// points outside the program.
    pub(crate) fn record_state(&mut self, state: &VMState) -> Result<(), InstructionError> {
        self.record_instruction_lookup(state.instruction_pointer)?;
        self.append_state_to_processor_trace(state);
        Ok(())
    }

    /// Bump the multiplicity counter for the instruction at the given
    /// address, rejecting out-of-bounds instruction pointers.
    fn record_instruction_lookup(
        &mut self,
        instruction_pointer: usize,
    ) -> Result<(), InstructionError> {
        if instruction_pointer >= self.instruction_multiplicities.len() {
            return Err(InstructionPointerOverflow);
        }
        self.instruction_multiplicities[instruction_pointer] += 1;
        Ok(())
    }

    /// Append the processor's current state as one row of the processor trace.
    fn append_state_to_processor_trace(&mut self, state: &VMState) {
        self.processor_trace
            .push_row(state.to_processor_row().view())
            .unwrap()
    }

    /// Dispatch a co-processor call to the table that records it.
    ///
    /// Note that match-arm order matters: a `Tip5Trace` caused by the `hash`
    /// instruction must be caught before the general `Tip5Trace` arm, which
    /// handles the Sponge instructions.
    pub(crate) fn record_co_processor_call(&mut self, co_processor_call: CoProcessorCall) {
        match co_processor_call {
            CoProcessorCall::Tip5Trace(Instruction::Hash, trace) => self.append_hash_trace(*trace),
            CoProcessorCall::SpongeStateReset => self.append_initial_sponge_state(),
            CoProcessorCall::Tip5Trace(instruction, trace) => {
                self.append_sponge_trace(instruction, *trace)
            }
            CoProcessorCall::U32(u32_entry) => self.record_u32_table_entry(u32_entry),
            CoProcessorCall::OpStack(op_stack_entry) => self.record_op_stack_entry(op_stack_entry),
            CoProcessorCall::Ram(ram_call) => self.record_ram_call(ram_call),
        }
    }

    /// Record a permutation trace caused by the `hash` instruction, stamping
    /// every new row's current-instruction column with the `hash` opcode.
    fn append_hash_trace(&mut self, trace: PermutationTrace) {
        self.increase_lookup_multiplicities(trace);
        let mut hash_trace_addendum = table::hash::trace_to_table_rows(trace);
        hash_trace_addendum
            .slice_mut(s![.., CI.main_index()])
            .fill(Instruction::Hash.opcode_b());
        self.hash_trace
            .append(ROW_AXIS, hash_trace_addendum.view())
            .expect("shapes must be identical");
    }

    /// Record `sponge_init`: a single sponge-trace row holding Tip5's initial
    /// state at round number 0.
    fn append_initial_sponge_state(&mut self) {
        let round_number = 0;
        let initial_state = Tip5::init().state;
        let mut hash_table_row = table::hash::trace_row_to_table_row(initial_state, round_number);
        hash_table_row[CI.main_index()] = Instruction::SpongeInit.opcode_b();
        self.sponge_trace.push_row(hash_table_row.view()).unwrap();
    }

    /// Record a permutation trace caused by `sponge_absorb` or
    /// `sponge_squeeze`, stamping every new row's current-instruction column
    /// with that instruction's opcode.
    ///
    /// # Panics
    ///
    /// Panics if `instruction` is any other instruction.
    fn append_sponge_trace(&mut self, instruction: Instruction, trace: PermutationTrace) {
        assert!(matches!(
            instruction,
            Instruction::SpongeAbsorb | Instruction::SpongeSqueeze
        ));
        self.increase_lookup_multiplicities(trace);
        let mut sponge_trace_addendum = table::hash::trace_to_table_rows(trace);
        sponge_trace_addendum
            .slice_mut(s![.., CI.main_index()])
            .fill(instruction.opcode_b());
        self.sponge_trace
            .append(ROW_AXIS, sponge_trace_addendum.view())
            .expect("shapes must be identical");
    }

    /// Bump cascade- and lookup-table multiplicities for one permutation.
    ///
    /// The trace's final row is excluded: it records the permutation's end
    /// state, for which no further lookups are performed.
    fn increase_lookup_multiplicities(&mut self, trace: PermutationTrace) {
        let rows_for_which_lookups_are_performed = trace.iter().dropping_back(1);
        for row in rows_for_which_lookups_are_performed {
            self.increase_lookup_multiplicities_for_row(row);
        }
    }

    /// Bump multiplicities for one trace row. Only the first
    /// `NUM_SPLIT_AND_LOOKUP` state elements undergo the split-and-lookup
    /// S-box; the rest of the state is not looked up.
    fn increase_lookup_multiplicities_for_row(&mut self, row: &[BFieldElement; tip5::STATE_SIZE]) {
        for &state_element in &row[0..tip5::NUM_SPLIT_AND_LOOKUP] {
            self.increase_lookup_multiplicities_for_state_element(state_element);
        }
    }

    /// Bump the Cascade Table multiplicity for each 16-bit limb of the given
    /// state element. The first time a limb is ever seen, the Lookup Table
    /// multiplicities for its two 8-bit halves are bumped as well — the
    /// Cascade Table consults the Lookup Table once per distinct limb.
    fn increase_lookup_multiplicities_for_state_element(&mut self, state_element: BFieldElement) {
        for limb in table::hash::base_field_element_into_16_bit_limbs(state_element) {
            match self.cascade_table_lookup_multiplicities.entry(limb) {
                Occupied(mut cascade_table_entry) => *cascade_table_entry.get_mut() += 1,
                Vacant(cascade_table_entry) => {
                    cascade_table_entry.insert(1);
                    self.increase_lookup_table_multiplicities_for_limb(limb);
                }
            }
        }
    }

    /// Bump the Lookup Table multiplicities for both 8-bit halves of a
    /// 16-bit limb.
    fn increase_lookup_table_multiplicities_for_limb(&mut self, limb: u16) {
        let limb_lo = limb & 0xff;
        let limb_hi = (limb >> 8) & 0xff;
        self.lookup_table_lookup_multiplicities[limb_lo as usize] += 1;
        self.lookup_table_lookup_multiplicities[limb_hi as usize] += 1;
    }

    /// Record one U32 Table lookup, inserting the entry on first occurrence.
    fn record_u32_table_entry(&mut self, u32_entry: U32TableEntry) {
        self.u32_entries.entry(u32_entry).or_insert(0).add_assign(1)
    }

    /// Append one op-stack underflow memory access as a new table row.
    fn record_op_stack_entry(&mut self, op_stack_entry: OpStackTableEntry) {
        let op_stack_table_row = op_stack_entry.to_main_table_row();
        self.op_stack_underflow_trace
            .push_row(op_stack_table_row.view())
            .unwrap();
    }

    /// Append one RAM access as a new table row.
    fn record_ram_call(&mut self, ram_call: RamTableCall) {
        self.ram_trace
            .push_row(ram_call.to_table_row().view())
            .unwrap();
    }
}
363
364impl TableHeight {
365 fn new(table: TableId, height: usize) -> Self {
366 Self { table, height }
367 }
368}
369
impl PartialOrd for TableHeight {
    // Canonical form: delegate to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
375
impl Ord for TableHeight {
    /// Order by height only; the table id is ignored.
    ///
    /// NOTE(review): two `TableHeight`s with equal heights but different
    /// tables compare as `Equal` here while the derived `PartialEq` says they
    /// differ — technically inconsistent with `Ord`'s contract. In practice
    /// this only affects which table `Iterator::max` reports on a tie;
    /// confirm this tie-breaking is acceptable before changing it.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.height.cmp(&other.height)
    }
}
381
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::assert;
    use isa::triton_asm;
    use isa::triton_program;

    use super::*;
    use crate::prelude::*;

    #[test]
    fn pad_program_requiring_no_padding_zeros() {
        // The program length is chosen such that, after appending the
        // mandatory 1, the input is already a multiple of the sponge rate —
        // so hash-input padding must add no zeros at all.
        let nops = triton_asm![nop; 8];
        let program = triton_program!({&nops} halt);
        let padded = AlgebraicExecutionTrace::hash_input_pad_program(&program);

        let mut expected = program.to_bwords();
        expected.extend(bfe_vec![1]);
        assert!(expected == padded);
    }

    #[test]
    fn height_of_any_table_can_be_computed() {
        // Running the smallest possible program must yield an AET for which
        // every table's height — and the overall height — is well-defined.
        let (aet, _) = VM::trace_execution(
            triton_program!(halt),
            PublicInput::default(),
            NonDeterminism::default(),
        )
        .unwrap();

        let _ = aet.height();
        TableId::iter().for_each(|table| {
            let _ = aet.height_of_table(table);
        });
    }
}