1use std::ops::AddAssign;
2
3use air::table::hash::HashTable;
4use air::table::hash::PermutationTrace;
5use air::table::op_stack;
6use air::table::processor;
7use air::table::ram;
8use air::table::TableId;
9use air::table_column::HashMainColumn::CI;
10use air::table_column::MasterMainColumn;
11use air::AIR;
12use arbitrary::Arbitrary;
13use indexmap::map::Entry::Occupied;
14use indexmap::map::Entry::Vacant;
15use indexmap::IndexMap;
16use isa::error::InstructionError;
17use isa::error::InstructionError::InstructionPointerOverflow;
18use isa::instruction::Instruction;
19use isa::program::Program;
20use itertools::Itertools;
21use ndarray::s;
22use ndarray::Array2;
23use ndarray::Axis;
24use strum::EnumCount;
25use strum::IntoEnumIterator;
26use twenty_first::prelude::*;
27
28use crate::table;
29use crate::table::op_stack::OpStackTableEntry;
30use crate::table::ram::RamTableCall;
31use crate::table::u32::U32TableEntry;
32use crate::vm::CoProcessorCall;
33use crate::vm::VMState;
34
#[derive(Debug, Clone)]
pub struct AlgebraicExecutionTrace {
    /// The program that was executed in order to produce this trace.
    pub program: Program,

    /// How often each instruction in [`program`](Self::program) was executed.
    /// One entry per instruction, indexed by the instruction's address.
    pub instruction_multiplicities: Vec<u32>,

    /// One row per executed instruction, recording the processor's state.
    pub processor_trace: Array2<BFieldElement>,

    /// One row per op-stack underflow memory access.
    pub op_stack_underflow_trace: Array2<BFieldElement>,

    /// One row per RAM access.
    pub ram_trace: Array2<BFieldElement>,

    /// The Tip5 permutation trace of hashing [`program`](Self::program);
    /// filled in eagerly on construction. See `fill_program_hash_trace`.
    pub program_hash_trace: Array2<BFieldElement>,

    /// Permutation traces resulting from executing instruction `hash`.
    pub hash_trace: Array2<BFieldElement>,

    /// Permutation traces resulting from the Sponge instructions
    /// (`sponge_init`, `sponge_absorb`, `sponge_squeeze`).
    pub sponge_trace: Array2<BFieldElement>,

    /// How often each U32 Table entry was requested by the processor.
    pub u32_entries: IndexMap<U32TableEntry, u64>,

    /// For each 16-bit limb that was looked up during a Tip5 permutation,
    /// how often that lookup happened. Feeds the Cascade Table.
    pub cascade_table_lookup_multiplicities: IndexMap<u16, u64>,

    /// For each possible 8-bit value, how often it was looked up by the
    /// Cascade Table. One slot per 8-bit value.
    pub lookup_table_lookup_multiplicities: [u64; AlgebraicExecutionTrace::LOOKUP_TABLE_HEIGHT],
}
87
/// A table's [`TableId`] paired with its height. Used to identify the
/// tallest table, which dictates the overall padded height of the trace.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct TableHeight {
    pub table: TableId,
    pub height: usize,
}
93
94impl AlgebraicExecutionTrace {
95 pub(crate) const LOOKUP_TABLE_HEIGHT: usize = 1 << 8;
96
97 pub fn new(program: Program) -> Self {
98 const PROCESSOR_WIDTH: usize = <processor::ProcessorTable as AIR>::MainColumn::COUNT;
99 const OP_STACK_WIDTH: usize = <op_stack::OpStackTable as AIR>::MainColumn::COUNT;
100 const RAM_WIDTH: usize = <ram::RamTable as AIR>::MainColumn::COUNT;
101 const HASH_WIDTH: usize = <HashTable as AIR>::MainColumn::COUNT;
102
103 let program_len = program.len_bwords();
104
105 let mut aet = Self {
106 program,
107 instruction_multiplicities: vec![0_u32; program_len],
108 processor_trace: Array2::default([0, PROCESSOR_WIDTH]),
109 op_stack_underflow_trace: Array2::default([0, OP_STACK_WIDTH]),
110 ram_trace: Array2::default([0, RAM_WIDTH]),
111 program_hash_trace: Array2::default([0, HASH_WIDTH]),
112 hash_trace: Array2::default([0, HASH_WIDTH]),
113 sponge_trace: Array2::default([0, HASH_WIDTH]),
114 u32_entries: IndexMap::new(),
115 cascade_table_lookup_multiplicities: IndexMap::new(),
116 lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT],
117 };
118 aet.fill_program_hash_trace();
119 aet
120 }
121
122 pub fn padded_height(&self) -> usize {
128 self.height().height.next_power_of_two()
129 }
130
131 pub fn height(&self) -> TableHeight {
136 TableId::iter()
137 .map(|t| TableHeight::new(t, self.height_of_table(t)))
138 .max()
139 .unwrap()
140 }
141
142 pub fn height_of_table(&self, table: TableId) -> usize {
143 let hash_table_height = || {
144 self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows()
145 };
146
147 match table {
148 TableId::Program => Self::padded_program_length(&self.program),
149 TableId::Processor => self.processor_trace.nrows(),
150 TableId::OpStack => self.op_stack_underflow_trace.nrows(),
151 TableId::Ram => self.ram_trace.nrows(),
152 TableId::JumpStack => self.processor_trace.nrows(),
153 TableId::Hash => hash_table_height(),
154 TableId::Cascade => self.cascade_table_lookup_multiplicities.len(),
155 TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT,
156 TableId::U32 => self.u32_table_height(),
157 }
158 }
159
160 fn u32_table_height(&self) -> usize {
165 let entry_len = U32TableEntry::table_height_contribution;
166 let height = self.u32_entries.keys().map(entry_len).sum::<u32>();
167 height.try_into().unwrap()
168 }
169
170 fn padded_program_length(program: &Program) -> usize {
171 (program.len_bwords() + 1).next_multiple_of(Tip5::RATE)
176 }
177
178 fn fill_program_hash_trace(&mut self) {
180 let padded_program = Self::hash_input_pad_program(&self.program);
181 let mut program_sponge = Tip5::init();
182 for chunk_to_absorb in padded_program.chunks(Tip5::RATE) {
183 program_sponge.state[..Tip5::RATE]
184 .iter_mut()
185 .zip_eq(chunk_to_absorb)
186 .for_each(|(sponge_state_elem, &absorb_elem)| *sponge_state_elem = absorb_elem);
187 let hash_trace = program_sponge.trace();
188 let trace_addendum = table::hash::trace_to_table_rows(hash_trace);
189
190 self.increase_lookup_multiplicities(hash_trace);
191 self.program_hash_trace
192 .append(Axis(0), trace_addendum.view())
193 .expect("shapes must be identical");
194 }
195
196 let instruction_column_index = CI.main_index();
197 let mut instruction_column = self.program_hash_trace.column_mut(instruction_column_index);
198 instruction_column.fill(Instruction::Hash.opcode_b());
199
200 let program_digest = program_sponge.state[..Digest::LEN].try_into().unwrap();
202 let program_digest = Digest::new(program_digest);
203 let expected_digest = self.program.hash();
204 assert_eq!(expected_digest, program_digest);
205 }
206
207 fn hash_input_pad_program(program: &Program) -> Vec<BFieldElement> {
208 let padded_program_length = Self::padded_program_length(program);
209
210 let program_iter = program.to_bwords().into_iter();
212 let one = bfe_array![1];
213 let zeros = bfe_array![0; tip5::RATE];
214 program_iter
215 .chain(one)
216 .chain(zeros)
217 .take(padded_program_length)
218 .collect()
219 }
220
221 pub(crate) fn record_state(&mut self, state: &VMState) -> Result<(), InstructionError> {
222 self.record_instruction_lookup(state.instruction_pointer)?;
223 self.append_state_to_processor_trace(state);
224 Ok(())
225 }
226
227 fn record_instruction_lookup(
228 &mut self,
229 instruction_pointer: usize,
230 ) -> Result<(), InstructionError> {
231 if instruction_pointer >= self.instruction_multiplicities.len() {
232 return Err(InstructionPointerOverflow);
233 }
234 self.instruction_multiplicities[instruction_pointer] += 1;
235 Ok(())
236 }
237
238 fn append_state_to_processor_trace(&mut self, state: &VMState) {
239 self.processor_trace
240 .push_row(state.to_processor_row().view())
241 .unwrap()
242 }
243
244 pub(crate) fn record_co_processor_call(&mut self, co_processor_call: CoProcessorCall) {
245 match co_processor_call {
246 CoProcessorCall::Tip5Trace(Instruction::Hash, trace) => self.append_hash_trace(*trace),
247 CoProcessorCall::SpongeStateReset => self.append_initial_sponge_state(),
248 CoProcessorCall::Tip5Trace(instruction, trace) => {
249 self.append_sponge_trace(instruction, *trace)
250 }
251 CoProcessorCall::U32(u32_entry) => self.record_u32_table_entry(u32_entry),
252 CoProcessorCall::OpStack(op_stack_entry) => self.record_op_stack_entry(op_stack_entry),
253 CoProcessorCall::Ram(ram_call) => self.record_ram_call(ram_call),
254 }
255 }
256
257 fn append_hash_trace(&mut self, trace: PermutationTrace) {
258 self.increase_lookup_multiplicities(trace);
259 let mut hash_trace_addendum = table::hash::trace_to_table_rows(trace);
260 hash_trace_addendum
261 .slice_mut(s![.., CI.main_index()])
262 .fill(Instruction::Hash.opcode_b());
263 self.hash_trace
264 .append(Axis(0), hash_trace_addendum.view())
265 .expect("shapes must be identical");
266 }
267
268 fn append_initial_sponge_state(&mut self) {
269 let round_number = 0;
270 let initial_state = Tip5::init().state;
271 let mut hash_table_row = table::hash::trace_row_to_table_row(initial_state, round_number);
272 hash_table_row[CI.main_index()] = Instruction::SpongeInit.opcode_b();
273 self.sponge_trace.push_row(hash_table_row.view()).unwrap();
274 }
275
276 fn append_sponge_trace(&mut self, instruction: Instruction, trace: PermutationTrace) {
277 assert!(matches!(
278 instruction,
279 Instruction::SpongeAbsorb | Instruction::SpongeSqueeze
280 ));
281 self.increase_lookup_multiplicities(trace);
282 let mut sponge_trace_addendum = table::hash::trace_to_table_rows(trace);
283 sponge_trace_addendum
284 .slice_mut(s![.., CI.main_index()])
285 .fill(instruction.opcode_b());
286 self.sponge_trace
287 .append(Axis(0), sponge_trace_addendum.view())
288 .expect("shapes must be identical");
289 }
290
291 fn increase_lookup_multiplicities(&mut self, trace: PermutationTrace) {
297 let rows_for_which_lookups_are_performed = trace.iter().dropping_back(1);
299 for row in rows_for_which_lookups_are_performed {
300 self.increase_lookup_multiplicities_for_row(row);
301 }
302 }
303
304 fn increase_lookup_multiplicities_for_row(&mut self, row: &[BFieldElement; tip5::STATE_SIZE]) {
307 for &state_element in &row[0..tip5::NUM_SPLIT_AND_LOOKUP] {
308 self.increase_lookup_multiplicities_for_state_element(state_element);
309 }
310 }
311
312 fn increase_lookup_multiplicities_for_state_element(&mut self, state_element: BFieldElement) {
315 for limb in table::hash::base_field_element_into_16_bit_limbs(state_element) {
316 match self.cascade_table_lookup_multiplicities.entry(limb) {
317 Occupied(mut cascade_table_entry) => *cascade_table_entry.get_mut() += 1,
318 Vacant(cascade_table_entry) => {
319 cascade_table_entry.insert(1);
320 self.increase_lookup_table_multiplicities_for_limb(limb);
321 }
322 }
323 }
324 }
325
326 fn increase_lookup_table_multiplicities_for_limb(&mut self, limb: u16) {
329 let limb_lo = limb & 0xff;
330 let limb_hi = (limb >> 8) & 0xff;
331 self.lookup_table_lookup_multiplicities[limb_lo as usize] += 1;
332 self.lookup_table_lookup_multiplicities[limb_hi as usize] += 1;
333 }
334
335 fn record_u32_table_entry(&mut self, u32_entry: U32TableEntry) {
336 self.u32_entries.entry(u32_entry).or_insert(0).add_assign(1)
337 }
338
339 fn record_op_stack_entry(&mut self, op_stack_entry: OpStackTableEntry) {
340 let op_stack_table_row = op_stack_entry.to_main_table_row();
341 self.op_stack_underflow_trace
342 .push_row(op_stack_table_row.view())
343 .unwrap();
344 }
345
346 fn record_ram_call(&mut self, ram_call: RamTableCall) {
347 self.ram_trace
348 .push_row(ram_call.to_table_row().view())
349 .unwrap();
350 }
351}
352
353impl TableHeight {
354 fn new(table: TableId, height: usize) -> Self {
355 Self { table, height }
356 }
357}
358
impl PartialOrd for TableHeight {
    // Delegates to `Ord`, as required for the two traits to agree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
364
impl Ord for TableHeight {
    // Ordering considers only the height, never the table ID. In particular,
    // `height()`'s `max()` picks the tallest table, with ties resolved by
    // `TableId` iteration order.
    //
    // NOTE(review): two values with equal `height` but different `table`
    // compare as `Equal` here yet are not `==` under the derived `PartialEq`.
    // This strains `Ord`'s consistency contract; confirm this asymmetry is
    // intentional before relying on `TableHeight` as a map/set key ordering.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.height.cmp(&other.height)
    }
}
370
#[cfg(test)]
mod tests {
    use assert2::assert;
    use isa::triton_asm;
    use isa::triton_program;

    use super::*;
    use crate::prelude::*;

    #[test]
    fn pad_program_requiring_no_padding_zeros() {
        // 8 nops + halt = 9 instructions; with the mandatory `1`, the padded
        // length is exactly one rate, so no zero-padding is needed.
        let nops = triton_asm![nop; 8];
        let program = triton_program!({&nops} halt);
        let padded = AlgebraicExecutionTrace::hash_input_pad_program(&program);

        let mut expected = program.to_bwords();
        expected.extend(bfe_vec![1]);
        assert!(expected == padded);
    }

    #[test]
    fn height_of_any_table_can_be_computed() {
        let program = triton_program!(halt);
        let public_input = PublicInput::default();
        let non_determinism = NonDeterminism::default();
        let (aet, _) = VM::trace_execution(program, public_input, non_determinism).unwrap();

        let _ = aet.height();
        for table in TableId::iter() {
            let _ = aet.height_of_table(table);
        }
    }
}