triton_vm/
aet.rs

1use std::ops::AddAssign;
2
3use air::table::hash::HashTable;
4use air::table::hash::PermutationTrace;
5use air::table::op_stack;
6use air::table::processor;
7use air::table::ram;
8use air::table::TableId;
9use air::table_column::HashMainColumn::CI;
10use air::table_column::MasterMainColumn;
11use air::AIR;
12use arbitrary::Arbitrary;
13use indexmap::map::Entry::Occupied;
14use indexmap::map::Entry::Vacant;
15use indexmap::IndexMap;
16use isa::error::InstructionError;
17use isa::error::InstructionError::InstructionPointerOverflow;
18use isa::instruction::Instruction;
19use isa::program::Program;
20use itertools::Itertools;
21use ndarray::s;
22use ndarray::Array2;
23use ndarray::Axis;
24use strum::EnumCount;
25use strum::IntoEnumIterator;
26use twenty_first::prelude::*;
27
28use crate::table;
29use crate::table::op_stack::OpStackTableEntry;
30use crate::table::ram::RamTableCall;
31use crate::table::u32::U32TableEntry;
32use crate::vm::CoProcessorCall;
33use crate::vm::VMState;
34
35/// An Algebraic Execution Trace (AET) is the primary witness required for proof generation. It
36/// holds every intermediate state of the processor and all co-processors, alongside additional
37/// witness information, such as the number of times each instruction has been looked up
38/// (equivalently, how often each instruction has been executed).
39#[derive(Debug, Clone)]
40pub struct AlgebraicExecutionTrace {
41    /// The program that was executed in order to generate the trace.
42    pub program: Program,
43
44    /// The number of times each instruction has been executed.
45    ///
46    /// Each instruction in the `program` has one associated entry in `instruction_multiplicities`,
47    /// counting the number of times this specific instruction at that location in the program
48    /// memory has been executed.
49    pub instruction_multiplicities: Vec<u32>,
50
51    /// Records the state of the processor after each instruction.
52    pub processor_trace: Array2<BFieldElement>,
53
54    pub op_stack_underflow_trace: Array2<BFieldElement>,
55
56    pub ram_trace: Array2<BFieldElement>,
57
58    /// The trace of hashing the program whose execution generated this `AlgebraicExecutionTrace`.
59    /// The resulting digest
60    /// 1. ties a [`Proof`](crate::proof::Proof) to the program it was produced from, and
61    /// 1. is accessible to the program being executed.
62    pub program_hash_trace: Array2<BFieldElement>,
63
64    /// For the `hash` instruction, the hash trace records the internal state of the Tip5
65    /// permutation for each round.
66    pub hash_trace: Array2<BFieldElement>,
67
68    /// For the Sponge instructions, i.e., `sponge_init`, `sponge_absorb`,
69    /// `sponge_absorb_mem`, and `sponge_squeeze`, the Sponge trace records the
70    /// internal state of the Tip5 permutation for each round.
71    pub sponge_trace: Array2<BFieldElement>,
72
73    /// The u32 entries hold all pairs of BFieldElements that were written to the U32 Table,
74    /// alongside the u32 instruction that was executed at the time. Additionally, it records how
75    /// often the instruction was executed with these arguments.
76    // `IndexMap` over `HashMap` for deterministic iteration order. This is not
77    // needed for correctness of the STARK.
78    pub u32_entries: IndexMap<U32TableEntry, u64>,
79
80    /// Records how often each entry in the cascade table was looked up.
81    // `IndexMap` over `HashMap` for the same reasons as for field `u32_entries`.
82    pub cascade_table_lookup_multiplicities: IndexMap<u16, u64>,
83
84    /// Records how often each entry in the lookup table was looked up.
85    pub lookup_table_lookup_multiplicities: [u64; AlgebraicExecutionTrace::LOOKUP_TABLE_HEIGHT],
86}
87
88#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
89pub struct TableHeight {
90    pub table: TableId,
91    pub height: usize,
92}
93
94impl AlgebraicExecutionTrace {
95    pub(crate) const LOOKUP_TABLE_HEIGHT: usize = 1 << 8;
96
97    pub fn new(program: Program) -> Self {
98        const PROCESSOR_WIDTH: usize = <processor::ProcessorTable as AIR>::MainColumn::COUNT;
99        const OP_STACK_WIDTH: usize = <op_stack::OpStackTable as AIR>::MainColumn::COUNT;
100        const RAM_WIDTH: usize = <ram::RamTable as AIR>::MainColumn::COUNT;
101        const HASH_WIDTH: usize = <HashTable as AIR>::MainColumn::COUNT;
102
103        let program_len = program.len_bwords();
104
105        let mut aet = Self {
106            program,
107            instruction_multiplicities: vec![0_u32; program_len],
108            processor_trace: Array2::default([0, PROCESSOR_WIDTH]),
109            op_stack_underflow_trace: Array2::default([0, OP_STACK_WIDTH]),
110            ram_trace: Array2::default([0, RAM_WIDTH]),
111            program_hash_trace: Array2::default([0, HASH_WIDTH]),
112            hash_trace: Array2::default([0, HASH_WIDTH]),
113            sponge_trace: Array2::default([0, HASH_WIDTH]),
114            u32_entries: IndexMap::new(),
115            cascade_table_lookup_multiplicities: IndexMap::new(),
116            lookup_table_lookup_multiplicities: [0; Self::LOOKUP_TABLE_HEIGHT],
117        };
118        aet.fill_program_hash_trace();
119        aet
120    }
121
122    /// The height of the [AET](AlgebraicExecutionTrace) after [padding][pad].
123    ///
124    /// Guaranteed to be a power of two.
125    ///
126    /// [pad]: table::master_table::MasterMainTable::pad
127    pub fn padded_height(&self) -> usize {
128        self.height().height.next_power_of_two()
129    }
130
131    /// The height of the [AET](AlgebraicExecutionTrace) before [padding][pad].
132    /// Corresponds to the height of the longest table.
133    ///
134    /// [pad]: table::master_table::MasterMainTable::pad
135    pub fn height(&self) -> TableHeight {
136        TableId::iter()
137            .map(|t| TableHeight::new(t, self.height_of_table(t)))
138            .max()
139            .unwrap()
140    }
141
142    pub fn height_of_table(&self, table: TableId) -> usize {
143        let hash_table_height = || {
144            self.sponge_trace.nrows() + self.hash_trace.nrows() + self.program_hash_trace.nrows()
145        };
146
147        match table {
148            TableId::Program => Self::padded_program_length(&self.program),
149            TableId::Processor => self.processor_trace.nrows(),
150            TableId::OpStack => self.op_stack_underflow_trace.nrows(),
151            TableId::Ram => self.ram_trace.nrows(),
152            TableId::JumpStack => self.processor_trace.nrows(),
153            TableId::Hash => hash_table_height(),
154            TableId::Cascade => self.cascade_table_lookup_multiplicities.len(),
155            TableId::Lookup => Self::LOOKUP_TABLE_HEIGHT,
156            TableId::U32 => self.u32_table_height(),
157        }
158    }
159
160    /// # Panics
161    ///
162    /// - if the table height exceeds [`u32::MAX`]
163    /// - if the table height exceeds [`usize::MAX`]
164    fn u32_table_height(&self) -> usize {
165        let entry_len = U32TableEntry::table_height_contribution;
166        let height = self.u32_entries.keys().map(entry_len).sum::<u32>();
167        height.try_into().unwrap()
168    }
169
170    fn padded_program_length(program: &Program) -> usize {
171        // Padding is at least one 1.
172        // Also note that the Program Table's side of the instruction lookup argument requires at
173        // least one padding row to account for the processor's “next instruction or argument.”
174        // Both of these are captured by the “+ 1” in the following line.
175        (program.len_bwords() + 1).next_multiple_of(Tip5::RATE)
176    }
177
178    /// Hash the program and record the entire Sponge's trace for program attestation.
179    fn fill_program_hash_trace(&mut self) {
180        let padded_program = Self::hash_input_pad_program(&self.program);
181        let mut program_sponge = Tip5::init();
182        for chunk_to_absorb in padded_program.chunks(Tip5::RATE) {
183            program_sponge.state[..Tip5::RATE]
184                .iter_mut()
185                .zip_eq(chunk_to_absorb)
186                .for_each(|(sponge_state_elem, &absorb_elem)| *sponge_state_elem = absorb_elem);
187            let hash_trace = program_sponge.trace();
188            let trace_addendum = table::hash::trace_to_table_rows(hash_trace);
189
190            self.increase_lookup_multiplicities(hash_trace);
191            self.program_hash_trace
192                .append(Axis(0), trace_addendum.view())
193                .expect("shapes must be identical");
194        }
195
196        let instruction_column_index = CI.main_index();
197        let mut instruction_column = self.program_hash_trace.column_mut(instruction_column_index);
198        instruction_column.fill(Instruction::Hash.opcode_b());
199
200        // consistency check
201        let program_digest = program_sponge.state[..Digest::LEN].try_into().unwrap();
202        let program_digest = Digest::new(program_digest);
203        let expected_digest = self.program.hash();
204        assert_eq!(expected_digest, program_digest);
205    }
206
207    fn hash_input_pad_program(program: &Program) -> Vec<BFieldElement> {
208        let padded_program_length = Self::padded_program_length(program);
209
210        // padding is one 1, then as many zeros as necessary: [1, 0, 0, …]
211        let program_iter = program.to_bwords().into_iter();
212        let one = bfe_array![1];
213        let zeros = bfe_array![0; tip5::RATE];
214        program_iter
215            .chain(one)
216            .chain(zeros)
217            .take(padded_program_length)
218            .collect()
219    }
220
221    pub(crate) fn record_state(&mut self, state: &VMState) -> Result<(), InstructionError> {
222        self.record_instruction_lookup(state.instruction_pointer)?;
223        self.append_state_to_processor_trace(state);
224        Ok(())
225    }
226
227    fn record_instruction_lookup(
228        &mut self,
229        instruction_pointer: usize,
230    ) -> Result<(), InstructionError> {
231        if instruction_pointer >= self.instruction_multiplicities.len() {
232            return Err(InstructionPointerOverflow);
233        }
234        self.instruction_multiplicities[instruction_pointer] += 1;
235        Ok(())
236    }
237
238    fn append_state_to_processor_trace(&mut self, state: &VMState) {
239        self.processor_trace
240            .push_row(state.to_processor_row().view())
241            .unwrap()
242    }
243
244    pub(crate) fn record_co_processor_call(&mut self, co_processor_call: CoProcessorCall) {
245        match co_processor_call {
246            CoProcessorCall::Tip5Trace(Instruction::Hash, trace) => self.append_hash_trace(*trace),
247            CoProcessorCall::SpongeStateReset => self.append_initial_sponge_state(),
248            CoProcessorCall::Tip5Trace(instruction, trace) => {
249                self.append_sponge_trace(instruction, *trace)
250            }
251            CoProcessorCall::U32(u32_entry) => self.record_u32_table_entry(u32_entry),
252            CoProcessorCall::OpStack(op_stack_entry) => self.record_op_stack_entry(op_stack_entry),
253            CoProcessorCall::Ram(ram_call) => self.record_ram_call(ram_call),
254        }
255    }
256
257    fn append_hash_trace(&mut self, trace: PermutationTrace) {
258        self.increase_lookup_multiplicities(trace);
259        let mut hash_trace_addendum = table::hash::trace_to_table_rows(trace);
260        hash_trace_addendum
261            .slice_mut(s![.., CI.main_index()])
262            .fill(Instruction::Hash.opcode_b());
263        self.hash_trace
264            .append(Axis(0), hash_trace_addendum.view())
265            .expect("shapes must be identical");
266    }
267
268    fn append_initial_sponge_state(&mut self) {
269        let round_number = 0;
270        let initial_state = Tip5::init().state;
271        let mut hash_table_row = table::hash::trace_row_to_table_row(initial_state, round_number);
272        hash_table_row[CI.main_index()] = Instruction::SpongeInit.opcode_b();
273        self.sponge_trace.push_row(hash_table_row.view()).unwrap();
274    }
275
276    fn append_sponge_trace(&mut self, instruction: Instruction, trace: PermutationTrace) {
277        assert!(matches!(
278            instruction,
279            Instruction::SpongeAbsorb | Instruction::SpongeSqueeze
280        ));
281        self.increase_lookup_multiplicities(trace);
282        let mut sponge_trace_addendum = table::hash::trace_to_table_rows(trace);
283        sponge_trace_addendum
284            .slice_mut(s![.., CI.main_index()])
285            .fill(instruction.opcode_b());
286        self.sponge_trace
287            .append(Axis(0), sponge_trace_addendum.view())
288            .expect("shapes must be identical");
289    }
290
291    /// Given a trace of the hash function's permutation, determines how often each entry in the
292    /// - cascade table was looked up, and
293    /// - lookup table was looked up;
294    ///
295    /// and increases the multiplicities accordingly
296    fn increase_lookup_multiplicities(&mut self, trace: PermutationTrace) {
297        // The last row in the trace is the permutation's result: no lookups are performed for it.
298        let rows_for_which_lookups_are_performed = trace.iter().dropping_back(1);
299        for row in rows_for_which_lookups_are_performed {
300            self.increase_lookup_multiplicities_for_row(row);
301        }
302    }
303
304    /// Given one row of the hash function's permutation trace, increase the multiplicities of the
305    /// relevant entries in the cascade table and/or the lookup table.
306    fn increase_lookup_multiplicities_for_row(&mut self, row: &[BFieldElement; tip5::STATE_SIZE]) {
307        for &state_element in &row[0..tip5::NUM_SPLIT_AND_LOOKUP] {
308            self.increase_lookup_multiplicities_for_state_element(state_element);
309        }
310    }
311
312    /// Given one state element, increase the multiplicities of the corresponding entries in the
313    /// cascade table and/or the lookup table.
314    fn increase_lookup_multiplicities_for_state_element(&mut self, state_element: BFieldElement) {
315        for limb in table::hash::base_field_element_into_16_bit_limbs(state_element) {
316            match self.cascade_table_lookup_multiplicities.entry(limb) {
317                Occupied(mut cascade_table_entry) => *cascade_table_entry.get_mut() += 1,
318                Vacant(cascade_table_entry) => {
319                    cascade_table_entry.insert(1);
320                    self.increase_lookup_table_multiplicities_for_limb(limb);
321                }
322            }
323        }
324    }
325
326    /// Given one 16-bit limb, increase the multiplicities of the corresponding entries in the
327    /// lookup table.
328    fn increase_lookup_table_multiplicities_for_limb(&mut self, limb: u16) {
329        let limb_lo = limb & 0xff;
330        let limb_hi = (limb >> 8) & 0xff;
331        self.lookup_table_lookup_multiplicities[limb_lo as usize] += 1;
332        self.lookup_table_lookup_multiplicities[limb_hi as usize] += 1;
333    }
334
335    fn record_u32_table_entry(&mut self, u32_entry: U32TableEntry) {
336        self.u32_entries.entry(u32_entry).or_insert(0).add_assign(1)
337    }
338
339    fn record_op_stack_entry(&mut self, op_stack_entry: OpStackTableEntry) {
340        let op_stack_table_row = op_stack_entry.to_main_table_row();
341        self.op_stack_underflow_trace
342            .push_row(op_stack_table_row.view())
343            .unwrap();
344    }
345
346    fn record_ram_call(&mut self, ram_call: RamTableCall) {
347        self.ram_trace
348            .push_row(ram_call.to_table_row().view())
349            .unwrap();
350    }
351}
352
353impl TableHeight {
354    fn new(table: TableId, height: usize) -> Self {
355        Self { table, height }
356    }
357}
358
359impl PartialOrd for TableHeight {
360    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
361        Some(self.cmp(other))
362    }
363}
364
365impl Ord for TableHeight {
366    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
367        self.height.cmp(&other.height)
368    }
369}
370
#[cfg(test)]
mod tests {
    use assert2::assert;
    use isa::triton_asm;
    use isa::triton_program;

    use super::*;
    use crate::prelude::*;

    /// Eight `nop`s plus one `halt` leave room for exactly one padding element, so the padded
    /// program must be the program's words followed by a single `1` and no zeros.
    #[test]
    fn pad_program_requiring_no_padding_zeros() {
        let nops = triton_asm![nop; 8];
        let program = triton_program!({&nops} halt);

        let mut expected = program.to_bwords();
        expected.extend(bfe_vec![1]);

        let actual = AlgebraicExecutionTrace::hash_input_pad_program(&program);
        assert!(expected == actual);
    }

    /// Neither the overall height nor the height of any individual table may panic.
    #[test]
    fn height_of_any_table_can_be_computed() {
        let (aet, _) = VM::trace_execution(
            triton_program!(halt),
            PublicInput::default(),
            NonDeterminism::default(),
        )
        .unwrap();

        let _ = aet.height();
        TableId::iter().for_each(|table| {
            let _ = aet.height_of_table(table);
        });
    }
}