triton_air/table/
hash.rs

1use constraint_circuit::ConstraintCircuitBuilder;
2use constraint_circuit::ConstraintCircuitMonad;
3use constraint_circuit::DualRowIndicator;
4use constraint_circuit::DualRowIndicator::CurrentAux;
5use constraint_circuit::DualRowIndicator::CurrentMain;
6use constraint_circuit::DualRowIndicator::NextAux;
7use constraint_circuit::DualRowIndicator::NextMain;
8use constraint_circuit::InputIndicator;
9use constraint_circuit::SingleRowIndicator;
10use constraint_circuit::SingleRowIndicator::Aux;
11use constraint_circuit::SingleRowIndicator::Main;
12use isa::instruction::Instruction;
13use itertools::Itertools;
14use strum::Display;
15use strum::EnumCount;
16use strum::EnumIter;
17use strum::IntoEnumIterator;
18use twenty_first::prelude::tip5::NUM_ROUNDS;
19use twenty_first::prelude::*;
20
21use crate::AIR;
22use crate::challenge_id::ChallengeId;
23use crate::cross_table_argument::CrossTableArg;
24use crate::cross_table_argument::EvalArg;
25use crate::cross_table_argument::LookupArg;
26use crate::table_column::MasterAuxColumn;
27use crate::table_column::MasterMainColumn;
28
29pub const MONTGOMERY_MODULUS: BFieldElement =
30    BFieldElement::new(((1_u128 << 64) % BFieldElement::P as u128) as u64);
31
32const POWER_MAP_EXPONENT: u64 = 7;
33const NUM_ROUND_CONSTANTS: usize = tip5::STATE_SIZE;
34
35pub const PERMUTATION_TRACE_LENGTH: usize = NUM_ROUNDS + 1;
36
37pub type PermutationTrace = [[BFieldElement; tip5::STATE_SIZE]; PERMUTATION_TRACE_LENGTH];
38
39#[derive(Debug, Copy, Clone, Eq, PartialEq)]
40pub struct HashTable;
41
42impl crate::private::Seal for HashTable {}
43
44type MainColumn = <HashTable as AIR>::MainColumn;
45type AuxColumn = <HashTable as AIR>::AuxColumn;
46
47impl HashTable {
48    /// Get the MDS matrix's entry in row `row_idx` and column `col_idx`.
49    const fn mds_matrix_entry(row_idx: usize, col_idx: usize) -> BFieldElement {
50        assert!(row_idx < tip5::STATE_SIZE);
51        assert!(col_idx < tip5::STATE_SIZE);
52        let index_in_matrix_defining_column =
53            (tip5::STATE_SIZE + row_idx - col_idx) % tip5::STATE_SIZE;
54        let mds_matrix_entry = tip5::MDS_MATRIX_FIRST_COLUMN[index_in_matrix_defining_column];
55        BFieldElement::new(mds_matrix_entry as u64)
56    }
57
58    /// The round constants for round `r` if it is a valid round number in the
59    /// Tip5 permutation, and the zero vector otherwise.
60    pub fn tip5_round_constants_by_round_number(r: usize) -> [BFieldElement; NUM_ROUND_CONSTANTS] {
61        if r >= NUM_ROUNDS {
62            return bfe_array![0; NUM_ROUND_CONSTANTS];
63        }
64
65        let range_start = NUM_ROUND_CONSTANTS * r;
66        let range_end = NUM_ROUND_CONSTANTS * (r + 1);
67        tip5::ROUND_CONSTANTS[range_start..range_end]
68            .try_into()
69            .unwrap()
70    }
71
72    /// Construct one of the states 0 through 3 from its constituent limbs.
73    /// For example, state 0 (prior to it being looked up in the
74    /// split-and-lookup S-Box, which is usually the desired version of the
75    /// state) is constructed from limbs [`State0HighestLkIn`][hi] through
76    /// [`State0LowestLkIn`][lo].
77    ///
78    /// States 4 through 15 are directly accessible. See also the slightly
79    /// related [`Self::state_column_by_index`].
80    ///
81    /// [hi]: crate::table_column::HashMainColumn::State0HighestLkIn
82    /// [lo]: crate::table_column::HashMainColumn::State0LowestLkIn
83    fn re_compose_16_bit_limbs<II: InputIndicator>(
84        circuit_builder: &ConstraintCircuitBuilder<II>,
85        highest: ConstraintCircuitMonad<II>,
86        mid_high: ConstraintCircuitMonad<II>,
87        mid_low: ConstraintCircuitMonad<II>,
88        lowest: ConstraintCircuitMonad<II>,
89    ) -> ConstraintCircuitMonad<II> {
90        let constant = |c: u64| circuit_builder.b_constant(c);
91        let montgomery_modulus_inv = circuit_builder.b_constant(MONTGOMERY_MODULUS.inverse());
92
93        let sum_of_shifted_limbs = highest * constant(1 << 48)
94            + mid_high * constant(1 << 32)
95            + mid_low * constant(1 << 16)
96            + lowest;
97        sum_of_shifted_limbs * montgomery_modulus_inv
98    }
99
100    /// A constraint circuit evaluating to zero if and only if the given
101    /// `round_number_circuit_node` is not equal to the given
102    /// `round_number_to_deselect`.
103    fn round_number_deselector<II: InputIndicator>(
104        circuit_builder: &ConstraintCircuitBuilder<II>,
105        round_number_circuit_node: &ConstraintCircuitMonad<II>,
106        round_number_to_deselect: usize,
107    ) -> ConstraintCircuitMonad<II> {
108        assert!(
109            round_number_to_deselect <= NUM_ROUNDS,
110            "Round number must be in [0, {NUM_ROUNDS}] but got {round_number_to_deselect}."
111        );
112        let constant = |c: u64| circuit_builder.b_constant(c);
113
114        // To not subtract zero from the first factor: some special casing.
115        let first_factor = match round_number_to_deselect {
116            0 => constant(1),
117            _ => round_number_circuit_node.clone(),
118        };
119        (1..=NUM_ROUNDS)
120            .filter(|&r| r != round_number_to_deselect)
121            .map(|r| round_number_circuit_node.clone() - constant(r as u64))
122            .fold(first_factor, |a, b| a * b)
123    }
124
125    /// A constraint circuit evaluating to zero if and only if the given
126    /// `mode_circuit_node` is equal to the given `mode_to_select`.
127    fn select_mode<II: InputIndicator>(
128        circuit_builder: &ConstraintCircuitBuilder<II>,
129        mode_circuit_node: &ConstraintCircuitMonad<II>,
130        mode_to_select: HashTableMode,
131    ) -> ConstraintCircuitMonad<II> {
132        mode_circuit_node.clone() - circuit_builder.b_constant(mode_to_select)
133    }
134
135    /// A constraint circuit evaluating to zero if and only if the given
136    /// `mode_circuit_node` is not equal to the given `mode_to_deselect`.
137    fn mode_deselector<II: InputIndicator>(
138        circuit_builder: &ConstraintCircuitBuilder<II>,
139        mode_circuit_node: &ConstraintCircuitMonad<II>,
140        mode_to_deselect: HashTableMode,
141    ) -> ConstraintCircuitMonad<II> {
142        let constant = |c: u64| circuit_builder.b_constant(c);
143        HashTableMode::iter()
144            .filter(|&mode| mode != mode_to_deselect)
145            .map(|mode| mode_circuit_node.clone() - constant(mode.into()))
146            .fold(constant(1), |accumulator, factor| accumulator * factor)
147    }
148
149    fn instruction_deselector<II: InputIndicator>(
150        circuit_builder: &ConstraintCircuitBuilder<II>,
151        current_instruction_node: &ConstraintCircuitMonad<II>,
152        instruction_to_deselect: Instruction,
153    ) -> ConstraintCircuitMonad<II> {
154        let constant = |c: u64| circuit_builder.b_constant(c);
155        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
156        let relevant_instructions = [
157            Instruction::Hash,
158            Instruction::SpongeInit,
159            Instruction::SpongeAbsorb,
160            Instruction::SpongeSqueeze,
161        ];
162        assert!(relevant_instructions.contains(&instruction_to_deselect));
163
164        relevant_instructions
165            .iter()
166            .filter(|&instruction| instruction != &instruction_to_deselect)
167            .map(|&instruction| current_instruction_node.clone() - opcode(instruction))
168            .fold(constant(1), |accumulator, factor| accumulator * factor)
169    }
170
171    /// The [main column][main_col] for the round constant corresponding to the
172    /// given index. Valid indices are 0 through 15, corresponding to the 16
173    /// round constants [`Constant0`][c0] through [`Constant15`][c15].
174    ///
175    /// [main_col]: crate::table_column::HashMainColumn
176    /// [c0]: crate::table_column::HashMainColumn::Constant0
177    /// [c15]: crate::table_column::HashMainColumn::Constant15
178    pub fn round_constant_column_by_index(index: usize) -> MainColumn {
179        match index {
180            0 => MainColumn::Constant0,
181            1 => MainColumn::Constant1,
182            2 => MainColumn::Constant2,
183            3 => MainColumn::Constant3,
184            4 => MainColumn::Constant4,
185            5 => MainColumn::Constant5,
186            6 => MainColumn::Constant6,
187            7 => MainColumn::Constant7,
188            8 => MainColumn::Constant8,
189            9 => MainColumn::Constant9,
190            10 => MainColumn::Constant10,
191            11 => MainColumn::Constant11,
192            12 => MainColumn::Constant12,
193            13 => MainColumn::Constant13,
194            14 => MainColumn::Constant14,
195            15 => MainColumn::Constant15,
196            _ => panic!("invalid constant column index"),
197        }
198    }
199
200    /// The [`HashMainColumn`][MainColumn] for the state corresponding to the
201    /// given index. Valid indices are 4 through 15, corresponding to the 12
202    /// state columns [`State4`][state_4] through [`State15`][state_15].
203    ///
204    /// States with indices 0 through 3 have to be assembled from the respective
205    /// limbs; see [`Self::re_compose_states_0_through_3_before_lookup`]
206    /// or [`Self::re_compose_16_bit_limbs`].
207    ///
208    /// [state_4]: crate::table_column::HashMainColumn::State4
209    /// [state_15]: crate::table_column::HashMainColumn::State15
210    fn state_column_by_index(index: usize) -> MainColumn {
211        match index {
212            4 => MainColumn::State4,
213            5 => MainColumn::State5,
214            6 => MainColumn::State6,
215            7 => MainColumn::State7,
216            8 => MainColumn::State8,
217            9 => MainColumn::State9,
218            10 => MainColumn::State10,
219            11 => MainColumn::State11,
220            12 => MainColumn::State12,
221            13 => MainColumn::State13,
222            14 => MainColumn::State14,
223            15 => MainColumn::State15,
224            _ => panic!("invalid state column index"),
225        }
226    }
227
228    fn indicate_column_index_in_main_row(column: MainColumn) -> SingleRowIndicator {
229        Main(column.master_main_index())
230    }
231
232    fn indicate_column_index_in_current_main_row(column: MainColumn) -> DualRowIndicator {
233        CurrentMain(column.master_main_index())
234    }
235
236    fn indicate_column_index_in_next_main_row(column: MainColumn) -> DualRowIndicator {
237        NextMain(column.master_main_index())
238    }
239
240    fn re_compose_states_0_through_3_before_lookup<II: InputIndicator>(
241        circuit_builder: &ConstraintCircuitBuilder<II>,
242        main_row_to_input_indicator: fn(MainColumn) -> II,
243    ) -> [ConstraintCircuitMonad<II>; 4] {
244        let input = |input_indicator: II| circuit_builder.input(input_indicator);
245        let state_0 = Self::re_compose_16_bit_limbs(
246            circuit_builder,
247            input(main_row_to_input_indicator(MainColumn::State0HighestLkIn)),
248            input(main_row_to_input_indicator(MainColumn::State0MidHighLkIn)),
249            input(main_row_to_input_indicator(MainColumn::State0MidLowLkIn)),
250            input(main_row_to_input_indicator(MainColumn::State0LowestLkIn)),
251        );
252        let state_1 = Self::re_compose_16_bit_limbs(
253            circuit_builder,
254            input(main_row_to_input_indicator(MainColumn::State1HighestLkIn)),
255            input(main_row_to_input_indicator(MainColumn::State1MidHighLkIn)),
256            input(main_row_to_input_indicator(MainColumn::State1MidLowLkIn)),
257            input(main_row_to_input_indicator(MainColumn::State1LowestLkIn)),
258        );
259        let state_2 = Self::re_compose_16_bit_limbs(
260            circuit_builder,
261            input(main_row_to_input_indicator(MainColumn::State2HighestLkIn)),
262            input(main_row_to_input_indicator(MainColumn::State2MidHighLkIn)),
263            input(main_row_to_input_indicator(MainColumn::State2MidLowLkIn)),
264            input(main_row_to_input_indicator(MainColumn::State2LowestLkIn)),
265        );
266        let state_3 = Self::re_compose_16_bit_limbs(
267            circuit_builder,
268            input(main_row_to_input_indicator(MainColumn::State3HighestLkIn)),
269            input(main_row_to_input_indicator(MainColumn::State3MidHighLkIn)),
270            input(main_row_to_input_indicator(MainColumn::State3MidLowLkIn)),
271            input(main_row_to_input_indicator(MainColumn::State3LowestLkIn)),
272        );
273        [state_0, state_1, state_2, state_3]
274    }
275
276    fn tip5_constraints_as_circuits(
277        circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
278    ) -> (
279        [ConstraintCircuitMonad<DualRowIndicator>; tip5::STATE_SIZE],
280        [ConstraintCircuitMonad<DualRowIndicator>; tip5::STATE_SIZE],
281    ) {
282        let constant = |c: u64| circuit_builder.b_constant(c);
283        let b_constant = |c| circuit_builder.b_constant(c);
284        let current_main_row = |column_idx: MainColumn| {
285            circuit_builder.input(CurrentMain(column_idx.master_main_index()))
286        };
287        let next_main_row = |column_idx: MainColumn| {
288            circuit_builder.input(NextMain(column_idx.master_main_index()))
289        };
290
291        let state_0_after_lookup = Self::re_compose_16_bit_limbs(
292            circuit_builder,
293            current_main_row(MainColumn::State0HighestLkOut),
294            current_main_row(MainColumn::State0MidHighLkOut),
295            current_main_row(MainColumn::State0MidLowLkOut),
296            current_main_row(MainColumn::State0LowestLkOut),
297        );
298        let state_1_after_lookup = Self::re_compose_16_bit_limbs(
299            circuit_builder,
300            current_main_row(MainColumn::State1HighestLkOut),
301            current_main_row(MainColumn::State1MidHighLkOut),
302            current_main_row(MainColumn::State1MidLowLkOut),
303            current_main_row(MainColumn::State1LowestLkOut),
304        );
305        let state_2_after_lookup = Self::re_compose_16_bit_limbs(
306            circuit_builder,
307            current_main_row(MainColumn::State2HighestLkOut),
308            current_main_row(MainColumn::State2MidHighLkOut),
309            current_main_row(MainColumn::State2MidLowLkOut),
310            current_main_row(MainColumn::State2LowestLkOut),
311        );
312        let state_3_after_lookup = Self::re_compose_16_bit_limbs(
313            circuit_builder,
314            current_main_row(MainColumn::State3HighestLkOut),
315            current_main_row(MainColumn::State3MidHighLkOut),
316            current_main_row(MainColumn::State3MidLowLkOut),
317            current_main_row(MainColumn::State3LowestLkOut),
318        );
319
320        let state_part_before_power_map: [_; tip5::STATE_SIZE - tip5::NUM_SPLIT_AND_LOOKUP] = [
321            MainColumn::State4,
322            MainColumn::State5,
323            MainColumn::State6,
324            MainColumn::State7,
325            MainColumn::State8,
326            MainColumn::State9,
327            MainColumn::State10,
328            MainColumn::State11,
329            MainColumn::State12,
330            MainColumn::State13,
331            MainColumn::State14,
332            MainColumn::State15,
333        ]
334        .map(current_main_row);
335
336        let state_part_after_power_map = {
337            let mut exponentiation_accumulator = state_part_before_power_map.clone();
338            for _ in 1..POWER_MAP_EXPONENT {
339                for (i, state) in exponentiation_accumulator.iter_mut().enumerate() {
340                    *state = state.clone() * state_part_before_power_map[i].clone();
341                }
342            }
343            exponentiation_accumulator
344        };
345
346        let state_after_s_box_application = [
347            state_0_after_lookup,
348            state_1_after_lookup,
349            state_2_after_lookup,
350            state_3_after_lookup,
351            state_part_after_power_map[0].clone(),
352            state_part_after_power_map[1].clone(),
353            state_part_after_power_map[2].clone(),
354            state_part_after_power_map[3].clone(),
355            state_part_after_power_map[4].clone(),
356            state_part_after_power_map[5].clone(),
357            state_part_after_power_map[6].clone(),
358            state_part_after_power_map[7].clone(),
359            state_part_after_power_map[8].clone(),
360            state_part_after_power_map[9].clone(),
361            state_part_after_power_map[10].clone(),
362            state_part_after_power_map[11].clone(),
363        ];
364
365        let mut state_after_matrix_multiplication = vec![constant(0); tip5::STATE_SIZE];
366        for (row_idx, acc) in state_after_matrix_multiplication.iter_mut().enumerate() {
367            for (col_idx, state) in state_after_s_box_application.iter().enumerate() {
368                let matrix_entry = b_constant(Self::mds_matrix_entry(row_idx, col_idx));
369                *acc = acc.clone() + matrix_entry * state.clone();
370            }
371        }
372
373        let round_constants: [_; tip5::STATE_SIZE] = [
374            MainColumn::Constant0,
375            MainColumn::Constant1,
376            MainColumn::Constant2,
377            MainColumn::Constant3,
378            MainColumn::Constant4,
379            MainColumn::Constant5,
380            MainColumn::Constant6,
381            MainColumn::Constant7,
382            MainColumn::Constant8,
383            MainColumn::Constant9,
384            MainColumn::Constant10,
385            MainColumn::Constant11,
386            MainColumn::Constant12,
387            MainColumn::Constant13,
388            MainColumn::Constant14,
389            MainColumn::Constant15,
390        ]
391        .map(current_main_row);
392
393        let state_after_round_constant_addition = state_after_matrix_multiplication
394            .into_iter()
395            .zip_eq(round_constants)
396            .map(|(st, rndc)| st + rndc)
397            .collect_vec();
398
399        let [state_0_next, state_1_next, state_2_next, state_3_next] =
400            Self::re_compose_states_0_through_3_before_lookup(
401                circuit_builder,
402                Self::indicate_column_index_in_next_main_row,
403            );
404        let state_next = [
405            state_0_next,
406            state_1_next,
407            state_2_next,
408            state_3_next,
409            next_main_row(MainColumn::State4),
410            next_main_row(MainColumn::State5),
411            next_main_row(MainColumn::State6),
412            next_main_row(MainColumn::State7),
413            next_main_row(MainColumn::State8),
414            next_main_row(MainColumn::State9),
415            next_main_row(MainColumn::State10),
416            next_main_row(MainColumn::State11),
417            next_main_row(MainColumn::State12),
418            next_main_row(MainColumn::State13),
419            next_main_row(MainColumn::State14),
420            next_main_row(MainColumn::State15),
421        ];
422
423        let round_number_next = next_main_row(MainColumn::RoundNumber);
424        let hash_function_round_correctly_performs_update = state_after_round_constant_addition
425            .into_iter()
426            .zip_eq(state_next.clone())
427            .map(|(state_element, state_element_next)| {
428                round_number_next.clone() * (state_element - state_element_next)
429            })
430            .collect_vec()
431            .try_into()
432            .unwrap();
433
434        (state_next, hash_function_round_correctly_performs_update)
435    }
436
437    fn cascade_log_derivative_update_circuit(
438        circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
439        look_in_column: MainColumn,
440        look_out_column: MainColumn,
441        cascade_log_derivative_column: AuxColumn,
442    ) -> ConstraintCircuitMonad<DualRowIndicator> {
443        let challenge = |c| circuit_builder.challenge(c);
444        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
445        let constant = |c: u32| circuit_builder.b_constant(c);
446        let next_main_row = |column_idx: MainColumn| {
447            circuit_builder.input(NextMain(column_idx.master_main_index()))
448        };
449        let current_aux_row = |column_idx: AuxColumn| {
450            circuit_builder.input(CurrentAux(column_idx.master_aux_index()))
451        };
452        let next_aux_row =
453            |column_idx: AuxColumn| circuit_builder.input(NextAux(column_idx.master_aux_index()));
454
455        let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate);
456        let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight);
457        let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight);
458
459        let ci_next = next_main_row(MainColumn::CI);
460        let mode_next = next_main_row(MainColumn::Mode);
461        let round_number_next = next_main_row(MainColumn::RoundNumber);
462        let cascade_log_derivative = current_aux_row(cascade_log_derivative_column);
463        let cascade_log_derivative_next = next_aux_row(cascade_log_derivative_column);
464
465        let compressed_row = look_in_weight * next_main_row(look_in_column)
466            + look_out_weight * next_main_row(look_out_column);
467
468        let cascade_log_derivative_remains =
469            cascade_log_derivative_next.clone() - cascade_log_derivative.clone();
470        let cascade_log_derivative_updates = (cascade_log_derivative_next - cascade_log_derivative)
471            * (cascade_indeterminate - compressed_row)
472            - constant(1);
473
474        let next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init =
475            Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
476                * (round_number_next.clone() - constant(NUM_ROUNDS as u32))
477                * (ci_next.clone() - opcode(Instruction::SpongeInit));
478        let round_number_next_is_not_num_rounds =
479            Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS);
480        let current_instruction_next_is_not_sponge_init =
481            Self::instruction_deselector(circuit_builder, &ci_next, Instruction::SpongeInit);
482
483        next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init
484            * cascade_log_derivative_updates
485            + round_number_next_is_not_num_rounds * cascade_log_derivative_remains.clone()
486            + current_instruction_next_is_not_sponge_init * cascade_log_derivative_remains
487    }
488}
489
490impl AIR for HashTable {
491    type MainColumn = crate::table_column::HashMainColumn;
492    type AuxColumn = crate::table_column::HashAuxColumn;
493
494    fn initial_constraints(
495        circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
496    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
497        let challenge = |c| circuit_builder.challenge(c);
498        let constant = |c: u64| circuit_builder.b_constant(c);
499
500        let main_row =
501            |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index()));
502        let aux_row =
503            |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index()));
504
505        let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial());
506        let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial());
507
508        let mode = main_row(Self::MainColumn::Mode);
509        let running_evaluation_hash_input = aux_row(Self::AuxColumn::HashInputRunningEvaluation);
510        let running_evaluation_hash_digest = aux_row(Self::AuxColumn::HashDigestRunningEvaluation);
511        let running_evaluation_sponge = aux_row(Self::AuxColumn::SpongeRunningEvaluation);
512        let running_evaluation_receive_chunk =
513            aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation);
514
515        let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate);
516        let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight);
517        let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight);
518        let prepare_chunk_indeterminate =
519            challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate);
520        let receive_chunk_indeterminate =
521            challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate);
522
523        // First chunk of the program is received correctly. Relates to program
524        // attestation.
525        let [state_0, state_1, state_2, state_3] =
526            Self::re_compose_states_0_through_3_before_lookup(
527                circuit_builder,
528                Self::indicate_column_index_in_main_row,
529            );
530        let state_rate_part: [_; tip5::RATE] = [
531            state_0,
532            state_1,
533            state_2,
534            state_3,
535            main_row(Self::MainColumn::State4),
536            main_row(Self::MainColumn::State5),
537            main_row(Self::MainColumn::State6),
538            main_row(Self::MainColumn::State7),
539            main_row(Self::MainColumn::State8),
540            main_row(Self::MainColumn::State9),
541        ];
542        let compressed_chunk = state_rate_part
543            .into_iter()
544            .fold(running_evaluation_initial.clone(), |acc, state_element| {
545                acc * prepare_chunk_indeterminate.clone() + state_element
546            });
547        let running_evaluation_receive_chunk_is_initialized_correctly =
548            running_evaluation_receive_chunk
549                - receive_chunk_indeterminate * running_evaluation_initial.clone()
550                - compressed_chunk;
551
552        // The lookup arguments with the Cascade Table for the S-Boxes are
553        // initialized correctly.
554        let cascade_log_derivative_init_circuit =
555            |look_in_column, look_out_column, cascade_log_derivative_column| {
556                let look_in = main_row(look_in_column);
557                let look_out = main_row(look_out_column);
558                let compressed_row =
559                    look_in_weight.clone() * look_in + look_out_weight.clone() * look_out;
560                let cascade_log_derivative = aux_row(cascade_log_derivative_column);
561                (cascade_log_derivative - lookup_arg_default_initial.clone())
562                    * (cascade_indeterminate.clone() - compressed_row)
563                    - constant(1)
564            };
565
566        // miscellaneous initial constraints
567        let mode_is_program_hashing =
568            Self::select_mode(circuit_builder, &mode, HashTableMode::ProgramHashing);
569        let round_number_is_0 = main_row(Self::MainColumn::RoundNumber);
570        let running_evaluation_hash_input_is_default_initial =
571            running_evaluation_hash_input - running_evaluation_initial.clone();
572        let running_evaluation_hash_digest_is_default_initial =
573            running_evaluation_hash_digest - running_evaluation_initial.clone();
574        let running_evaluation_sponge_is_default_initial =
575            running_evaluation_sponge - running_evaluation_initial;
576
577        vec![
578            mode_is_program_hashing,
579            round_number_is_0,
580            running_evaluation_hash_input_is_default_initial,
581            running_evaluation_hash_digest_is_default_initial,
582            running_evaluation_sponge_is_default_initial,
583            running_evaluation_receive_chunk_is_initialized_correctly,
584            cascade_log_derivative_init_circuit(
585                Self::MainColumn::State0HighestLkIn,
586                Self::MainColumn::State0HighestLkOut,
587                Self::AuxColumn::CascadeState0HighestClientLogDerivative,
588            ),
589            cascade_log_derivative_init_circuit(
590                Self::MainColumn::State0MidHighLkIn,
591                Self::MainColumn::State0MidHighLkOut,
592                Self::AuxColumn::CascadeState0MidHighClientLogDerivative,
593            ),
594            cascade_log_derivative_init_circuit(
595                Self::MainColumn::State0MidLowLkIn,
596                Self::MainColumn::State0MidLowLkOut,
597                Self::AuxColumn::CascadeState0MidLowClientLogDerivative,
598            ),
599            cascade_log_derivative_init_circuit(
600                Self::MainColumn::State0LowestLkIn,
601                Self::MainColumn::State0LowestLkOut,
602                Self::AuxColumn::CascadeState0LowestClientLogDerivative,
603            ),
604            cascade_log_derivative_init_circuit(
605                Self::MainColumn::State1HighestLkIn,
606                Self::MainColumn::State1HighestLkOut,
607                Self::AuxColumn::CascadeState1HighestClientLogDerivative,
608            ),
609            cascade_log_derivative_init_circuit(
610                Self::MainColumn::State1MidHighLkIn,
611                Self::MainColumn::State1MidHighLkOut,
612                Self::AuxColumn::CascadeState1MidHighClientLogDerivative,
613            ),
614            cascade_log_derivative_init_circuit(
615                Self::MainColumn::State1MidLowLkIn,
616                Self::MainColumn::State1MidLowLkOut,
617                Self::AuxColumn::CascadeState1MidLowClientLogDerivative,
618            ),
619            cascade_log_derivative_init_circuit(
620                Self::MainColumn::State1LowestLkIn,
621                Self::MainColumn::State1LowestLkOut,
622                Self::AuxColumn::CascadeState1LowestClientLogDerivative,
623            ),
624            cascade_log_derivative_init_circuit(
625                Self::MainColumn::State2HighestLkIn,
626                Self::MainColumn::State2HighestLkOut,
627                Self::AuxColumn::CascadeState2HighestClientLogDerivative,
628            ),
629            cascade_log_derivative_init_circuit(
630                Self::MainColumn::State2MidHighLkIn,
631                Self::MainColumn::State2MidHighLkOut,
632                Self::AuxColumn::CascadeState2MidHighClientLogDerivative,
633            ),
634            cascade_log_derivative_init_circuit(
635                Self::MainColumn::State2MidLowLkIn,
636                Self::MainColumn::State2MidLowLkOut,
637                Self::AuxColumn::CascadeState2MidLowClientLogDerivative,
638            ),
639            cascade_log_derivative_init_circuit(
640                Self::MainColumn::State2LowestLkIn,
641                Self::MainColumn::State2LowestLkOut,
642                Self::AuxColumn::CascadeState2LowestClientLogDerivative,
643            ),
644            cascade_log_derivative_init_circuit(
645                Self::MainColumn::State3HighestLkIn,
646                Self::MainColumn::State3HighestLkOut,
647                Self::AuxColumn::CascadeState3HighestClientLogDerivative,
648            ),
649            cascade_log_derivative_init_circuit(
650                Self::MainColumn::State3MidHighLkIn,
651                Self::MainColumn::State3MidHighLkOut,
652                Self::AuxColumn::CascadeState3MidHighClientLogDerivative,
653            ),
654            cascade_log_derivative_init_circuit(
655                Self::MainColumn::State3MidLowLkIn,
656                Self::MainColumn::State3MidLowLkOut,
657                Self::AuxColumn::CascadeState3MidLowClientLogDerivative,
658            ),
659            cascade_log_derivative_init_circuit(
660                Self::MainColumn::State3LowestLkIn,
661                Self::MainColumn::State3LowestLkOut,
662                Self::AuxColumn::CascadeState3LowestClientLogDerivative,
663            ),
664        ]
665    }
666
667    fn consistency_constraints(
668        circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
669    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
670        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
671        let constant = |c: u64| circuit_builder.b_constant(c);
672        let main_row = |column_id: Self::MainColumn| {
673            circuit_builder.input(Main(column_id.master_main_index()))
674        };
675
676        let mode = main_row(Self::MainColumn::Mode);
677        let ci = main_row(Self::MainColumn::CI);
678        let round_number = main_row(Self::MainColumn::RoundNumber);
679
680        let ci_is_hash = ci.clone() - opcode(Instruction::Hash);
681        let ci_is_sponge_init = ci.clone() - opcode(Instruction::SpongeInit);
682        let ci_is_sponge_absorb = ci.clone() - opcode(Instruction::SpongeAbsorb);
683        let ci_is_sponge_squeeze = ci - opcode(Instruction::SpongeSqueeze);
684
685        let mode_is_not_hash = Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash);
686        let round_number_is_not_0 =
687            Self::round_number_deselector(circuit_builder, &round_number, 0);
688
689        let mode_is_a_valid_mode =
690            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
691                * Self::select_mode(circuit_builder, &mode, HashTableMode::Pad);
692
693        let if_mode_is_not_sponge_then_ci_is_hash =
694            Self::select_mode(circuit_builder, &mode, HashTableMode::Sponge) * ci_is_hash.clone();
695
696        let if_mode_is_sponge_then_ci_is_a_sponge_instruction =
697            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge)
698                * ci_is_sponge_init
699                * ci_is_sponge_absorb.clone()
700                * ci_is_sponge_squeeze.clone();
701
702        let if_padding_mode_then_round_number_is_0 =
703            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
704                * round_number.clone();
705
706        let if_ci_is_sponge_init_then_ = ci_is_hash * ci_is_sponge_absorb * ci_is_sponge_squeeze;
707        let if_ci_is_sponge_init_then_round_number_is_0 =
708            if_ci_is_sponge_init_then_.clone() * round_number.clone();
709
710        let if_ci_is_sponge_init_then_rate_is_0 = (10..=15).map(|state_index| {
711            let state_element = main_row(Self::state_column_by_index(state_index));
712            if_ci_is_sponge_init_then_.clone() * state_element
713        });
714
715        let if_mode_is_hash_and_round_no_is_0_then_ = round_number_is_not_0 * mode_is_not_hash;
716        let if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1 =
717            (10..=15).map(|state_index| {
718                let state_element = main_row(Self::state_column_by_index(state_index));
719                if_mode_is_hash_and_round_no_is_0_then_.clone() * (state_element - constant(1))
720            });
721
722        // consistency of the inverse of the highest 2 limbs minus 2^32 - 1
723        let one = constant(1);
724        let two_pow_16 = constant(1 << 16);
725        let two_pow_32 = constant(1 << 32);
726        let state_0_hi_limbs_minus_2_pow_32 = two_pow_32.clone()
727            - one.clone()
728            - main_row(Self::MainColumn::State0HighestLkIn) * two_pow_16.clone()
729            - main_row(Self::MainColumn::State0MidHighLkIn);
730        let state_1_hi_limbs_minus_2_pow_32 = two_pow_32.clone()
731            - one.clone()
732            - main_row(Self::MainColumn::State1HighestLkIn) * two_pow_16.clone()
733            - main_row(Self::MainColumn::State1MidHighLkIn);
734        let state_2_hi_limbs_minus_2_pow_32 = two_pow_32.clone()
735            - one.clone()
736            - main_row(Self::MainColumn::State2HighestLkIn) * two_pow_16.clone()
737            - main_row(Self::MainColumn::State2MidHighLkIn);
738        let state_3_hi_limbs_minus_2_pow_32 = two_pow_32
739            - one.clone()
740            - main_row(Self::MainColumn::State3HighestLkIn) * two_pow_16.clone()
741            - main_row(Self::MainColumn::State3MidHighLkIn);
742
743        let state_0_hi_limbs_inv = main_row(Self::MainColumn::State0Inv);
744        let state_1_hi_limbs_inv = main_row(Self::MainColumn::State1Inv);
745        let state_2_hi_limbs_inv = main_row(Self::MainColumn::State2Inv);
746        let state_3_hi_limbs_inv = main_row(Self::MainColumn::State3Inv);
747
748        let state_0_hi_limbs_are_not_all_1s =
749            state_0_hi_limbs_minus_2_pow_32.clone() * state_0_hi_limbs_inv.clone() - one.clone();
750        let state_1_hi_limbs_are_not_all_1s =
751            state_1_hi_limbs_minus_2_pow_32.clone() * state_1_hi_limbs_inv.clone() - one.clone();
752        let state_2_hi_limbs_are_not_all_1s =
753            state_2_hi_limbs_minus_2_pow_32.clone() * state_2_hi_limbs_inv.clone() - one.clone();
754        let state_3_hi_limbs_are_not_all_1s =
755            state_3_hi_limbs_minus_2_pow_32.clone() * state_3_hi_limbs_inv.clone() - one;
756
757        let state_0_hi_limbs_inv_is_inv_or_is_zero =
758            state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_inv;
759        let state_1_hi_limbs_inv_is_inv_or_is_zero =
760            state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_inv;
761        let state_2_hi_limbs_inv_is_inv_or_is_zero =
762            state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_inv;
763        let state_3_hi_limbs_inv_is_inv_or_is_zero =
764            state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_inv;
765
766        let state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero =
767            state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_minus_2_pow_32;
768        let state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero =
769            state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_minus_2_pow_32;
770        let state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero =
771            state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_minus_2_pow_32;
772        let state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero =
773            state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_minus_2_pow_32;
774
775        // consistent decomposition into limbs
776        let state_0_lo_limbs = main_row(Self::MainColumn::State0MidLowLkIn) * two_pow_16.clone()
777            + main_row(Self::MainColumn::State0LowestLkIn);
778        let state_1_lo_limbs = main_row(Self::MainColumn::State1MidLowLkIn) * two_pow_16.clone()
779            + main_row(Self::MainColumn::State1LowestLkIn);
780        let state_2_lo_limbs = main_row(Self::MainColumn::State2MidLowLkIn) * two_pow_16.clone()
781            + main_row(Self::MainColumn::State2LowestLkIn);
782        let state_3_lo_limbs = main_row(Self::MainColumn::State3MidLowLkIn) * two_pow_16
783            + main_row(Self::MainColumn::State3LowestLkIn);
784
785        let if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0 =
786            state_0_hi_limbs_are_not_all_1s * state_0_lo_limbs;
787        let if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0 =
788            state_1_hi_limbs_are_not_all_1s * state_1_lo_limbs;
789        let if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0 =
790            state_2_hi_limbs_are_not_all_1s * state_2_lo_limbs;
791        let if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0 =
792            state_3_hi_limbs_are_not_all_1s * state_3_lo_limbs;
793
794        let mut constraints = vec![
795            mode_is_a_valid_mode,
796            if_mode_is_not_sponge_then_ci_is_hash,
797            if_mode_is_sponge_then_ci_is_a_sponge_instruction,
798            if_padding_mode_then_round_number_is_0,
799            if_ci_is_sponge_init_then_round_number_is_0,
800            state_0_hi_limbs_inv_is_inv_or_is_zero,
801            state_1_hi_limbs_inv_is_inv_or_is_zero,
802            state_2_hi_limbs_inv_is_inv_or_is_zero,
803            state_3_hi_limbs_inv_is_inv_or_is_zero,
804            state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero,
805            state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero,
806            state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero,
807            state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero,
808            if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0,
809            if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0,
810            if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0,
811            if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0,
812        ];
813
814        constraints.extend(if_ci_is_sponge_init_then_rate_is_0);
815        constraints.extend(if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1);
816
817        for round_constant_column_idx in 0..NUM_ROUND_CONSTANTS {
818            let round_constant_column =
819                Self::round_constant_column_by_index(round_constant_column_idx);
820            let round_constant_column_circuit = main_row(round_constant_column);
821            let mut round_constant_constraint_circuit = constant(0);
822            for round_idx in 0..NUM_ROUNDS {
823                let round_constants = Self::tip5_round_constants_by_round_number(round_idx);
824                let round_constant = round_constants[round_constant_column_idx];
825                let round_constant = circuit_builder.b_constant(round_constant);
826                let round_deselector_circuit =
827                    Self::round_number_deselector(circuit_builder, &round_number, round_idx);
828                round_constant_constraint_circuit = round_constant_constraint_circuit
829                    + round_deselector_circuit
830                        * (round_constant_column_circuit.clone() - round_constant);
831            }
832            constraints.push(round_constant_constraint_circuit);
833        }
834
835        constraints
836    }
837
838    fn transition_constraints(
839        circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
840    ) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
841        let challenge = |c| circuit_builder.challenge(c);
842        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
843        let constant = |c: u64| circuit_builder.b_constant(c);
844
845        let opcode_hash = opcode(Instruction::Hash);
846        let opcode_sponge_init = opcode(Instruction::SpongeInit);
847        let opcode_sponge_absorb = opcode(Instruction::SpongeAbsorb);
848        let opcode_sponge_squeeze = opcode(Instruction::SpongeSqueeze);
849
850        let current_main_row = |column_idx: Self::MainColumn| {
851            circuit_builder.input(CurrentMain(column_idx.master_main_index()))
852        };
853        let next_main_row = |column_idx: Self::MainColumn| {
854            circuit_builder.input(NextMain(column_idx.master_main_index()))
855        };
856        let current_aux_row = |column_idx: Self::AuxColumn| {
857            circuit_builder.input(CurrentAux(column_idx.master_aux_index()))
858        };
859        let next_aux_row = |column_idx: Self::AuxColumn| {
860            circuit_builder.input(NextAux(column_idx.master_aux_index()))
861        };
862
863        let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial());
864
865        let prepare_chunk_indeterminate =
866            challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate);
867        let receive_chunk_indeterminate =
868            challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate);
869        let compress_program_digest_indeterminate =
870            challenge(ChallengeId::CompressProgramDigestIndeterminate);
871        let expected_program_digest = challenge(ChallengeId::CompressedProgramDigest);
872        let hash_input_eval_indeterminate = challenge(ChallengeId::HashInputIndeterminate);
873        let hash_digest_eval_indeterminate = challenge(ChallengeId::HashDigestIndeterminate);
874        let sponge_indeterminate = challenge(ChallengeId::SpongeIndeterminate);
875
876        let mode = current_main_row(Self::MainColumn::Mode);
877        let ci = current_main_row(Self::MainColumn::CI);
878        let round_number = current_main_row(Self::MainColumn::RoundNumber);
879        let running_evaluation_receive_chunk =
880            current_aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation);
881        let running_evaluation_hash_input =
882            current_aux_row(Self::AuxColumn::HashInputRunningEvaluation);
883        let running_evaluation_hash_digest =
884            current_aux_row(Self::AuxColumn::HashDigestRunningEvaluation);
885        let running_evaluation_sponge = current_aux_row(Self::AuxColumn::SpongeRunningEvaluation);
886
887        let mode_next = next_main_row(Self::MainColumn::Mode);
888        let ci_next = next_main_row(Self::MainColumn::CI);
889        let round_number_next = next_main_row(Self::MainColumn::RoundNumber);
890        let running_evaluation_receive_chunk_next =
891            next_aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation);
892        let running_evaluation_hash_input_next =
893            next_aux_row(Self::AuxColumn::HashInputRunningEvaluation);
894        let running_evaluation_hash_digest_next =
895            next_aux_row(Self::AuxColumn::HashDigestRunningEvaluation);
896        let running_evaluation_sponge_next = next_aux_row(Self::AuxColumn::SpongeRunningEvaluation);
897
898        let [state_0, state_1, state_2, state_3] =
899            Self::re_compose_states_0_through_3_before_lookup(
900                circuit_builder,
901                Self::indicate_column_index_in_current_main_row,
902            );
903
904        let state_current = [
905            state_0,
906            state_1,
907            state_2,
908            state_3,
909            current_main_row(Self::MainColumn::State4),
910            current_main_row(Self::MainColumn::State5),
911            current_main_row(Self::MainColumn::State6),
912            current_main_row(Self::MainColumn::State7),
913            current_main_row(Self::MainColumn::State8),
914            current_main_row(Self::MainColumn::State9),
915            current_main_row(Self::MainColumn::State10),
916            current_main_row(Self::MainColumn::State11),
917            current_main_row(Self::MainColumn::State12),
918            current_main_row(Self::MainColumn::State13),
919            current_main_row(Self::MainColumn::State14),
920            current_main_row(Self::MainColumn::State15),
921        ];
922
923        let (state_next, hash_function_round_correctly_performs_update) =
924            Self::tip5_constraints_as_circuits(circuit_builder);
925
926        let state_weights = [
927            ChallengeId::StackWeight0,
928            ChallengeId::StackWeight1,
929            ChallengeId::StackWeight2,
930            ChallengeId::StackWeight3,
931            ChallengeId::StackWeight4,
932            ChallengeId::StackWeight5,
933            ChallengeId::StackWeight6,
934            ChallengeId::StackWeight7,
935            ChallengeId::StackWeight8,
936            ChallengeId::StackWeight9,
937            ChallengeId::StackWeight10,
938            ChallengeId::StackWeight11,
939            ChallengeId::StackWeight12,
940            ChallengeId::StackWeight13,
941            ChallengeId::StackWeight14,
942            ChallengeId::StackWeight15,
943        ]
944        .map(challenge);
945
946        let round_number_is_not_num_rounds =
947            Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS);
948
949        let round_number_is_0_through_4_or_round_number_next_is_0 =
950            round_number_is_not_num_rounds * round_number_next.clone();
951
952        let next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one =
953            Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
954                * (ci.clone() - opcode_sponge_init.clone())
955                * (round_number.clone() - constant(NUM_ROUNDS as u64))
956                * (round_number_next.clone() - round_number.clone() - constant(1));
957
958        // compress the digest by computing the terminal of an evaluation
959        // argument
960        let compressed_digest = state_current[..Digest::LEN].iter().fold(
961            running_evaluation_initial.clone(),
962            |acc, digest_element| {
963                acc * compress_program_digest_indeterminate.clone() + digest_element.clone()
964            },
965        );
966        let if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest =
967            Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
968                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
969                * (compressed_digest - expected_program_digest);
970
971        let if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init =
972            Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
973                * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Sponge)
974                * (ci_next.clone() - opcode_sponge_init.clone());
975
976        let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change =
977            (round_number.clone() - constant(NUM_ROUNDS as u64))
978                * (ci.clone() - opcode_sponge_init.clone())
979                * (ci_next.clone() - ci.clone());
980
981        let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change =
982            (round_number - constant(NUM_ROUNDS as u64))
983                * (ci.clone() - opcode_sponge_init.clone())
984                * (mode_next.clone() - mode.clone());
985
986        let if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad =
987            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge)
988                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Sponge)
989                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
990                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);
991
992        let if_mode_is_hash_then_mode_next_is_hash_or_pad =
993            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash)
994                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
995                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);
996
997        let if_mode_is_pad_then_mode_next_is_pad =
998            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
999                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);
1000
1001        let difference_of_capacity_registers = state_current[tip5::RATE..]
1002            .iter()
1003            .zip_eq(state_next[tip5::RATE..].iter())
1004            .map(|(current, next)| next.clone() - current.clone())
1005            .collect_vec();
1006        let randomized_sum_of_capacity_differences = state_weights[tip5::RATE..]
1007            .iter()
1008            .zip_eq(difference_of_capacity_registers)
1009            .map(|(weight, state_difference)| weight.clone() * state_difference)
1010            .sum::<ConstraintCircuitMonad<_>>();
1011
1012        let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing =
1013            Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1014                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
1015                * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
1016                * (ci_next.clone() - opcode_sponge_init.clone())
1017                * randomized_sum_of_capacity_differences.clone();
1018
1019        let difference_of_state_registers = state_current
1020            .iter()
1021            .zip_eq(state_next.iter())
1022            .map(|(current, next)| next.clone() - current.clone())
1023            .collect_vec();
1024        let randomized_sum_of_state_differences = state_weights
1025            .iter()
1026            .zip_eq(difference_of_state_registers.iter())
1027            .map(|(weight, state_difference)| weight.clone() * state_difference.clone())
1028            .sum();
1029        let if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change =
1030            Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1031                * Self::instruction_deselector(
1032                    circuit_builder,
1033                    &ci_next,
1034                    Instruction::SpongeSqueeze,
1035                )
1036                * randomized_sum_of_state_differences;
1037
1038        // Evaluation Arguments
1039
1040        // If (and only if) the row number in the next row is 0 and the mode in
1041        // the next row is `hash`, update running evaluation “hash input.”
1042        let running_evaluation_hash_input_remains =
1043            running_evaluation_hash_input_next.clone() - running_evaluation_hash_input.clone();
1044        let tip5_input = state_next[..tip5::RATE].to_owned();
1045        let compressed_row_from_processor = tip5_input
1046            .into_iter()
1047            .zip_eq(state_weights[..tip5::RATE].iter())
1048            .map(|(state, weight)| weight.clone() * state)
1049            .sum();
1050
1051        let running_evaluation_hash_input_updates = running_evaluation_hash_input_next
1052            - hash_input_eval_indeterminate * running_evaluation_hash_input
1053            - compressed_row_from_processor;
1054        let running_evaluation_hash_input_is_updated_correctly =
1055            Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1056                * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash)
1057                * running_evaluation_hash_input_updates
1058                + round_number_next.clone() * running_evaluation_hash_input_remains.clone()
1059                + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
1060                    * running_evaluation_hash_input_remains;
1061
1062        // If (and only if) the row number in the next row is NUM_ROUNDS and the
1063        // current instruction in the next row corresponds to `hash`, update
1064        // running evaluation “hash digest.”
1065        let round_number_next_is_num_rounds =
1066            round_number_next.clone() - constant(NUM_ROUNDS as u64);
1067        let running_evaluation_hash_digest_remains =
1068            running_evaluation_hash_digest_next.clone() - running_evaluation_hash_digest.clone();
1069        let hash_digest = state_next[..Digest::LEN].to_owned();
1070        let compressed_row_hash_digest = hash_digest
1071            .into_iter()
1072            .zip_eq(state_weights[..Digest::LEN].iter())
1073            .map(|(state, weight)| weight.clone() * state)
1074            .sum();
1075        let running_evaluation_hash_digest_updates = running_evaluation_hash_digest_next
1076            - hash_digest_eval_indeterminate * running_evaluation_hash_digest
1077            - compressed_row_hash_digest;
1078        let running_evaluation_hash_digest_is_updated_correctly =
1079            Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS)
1080                * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash)
1081                * running_evaluation_hash_digest_updates
1082                + round_number_next_is_num_rounds * running_evaluation_hash_digest_remains.clone()
1083                + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
1084                    * running_evaluation_hash_digest_remains;
1085
1086        // The running evaluation for “Sponge” updates correctly.
1087        let compressed_row_next = state_weights[..tip5::RATE]
1088            .iter()
1089            .zip_eq(state_next[..tip5::RATE].iter())
1090            .map(|(weight, st_next)| weight.clone() * st_next.clone())
1091            .sum();
1092        let running_evaluation_sponge_has_accumulated_ci = running_evaluation_sponge_next.clone()
1093            - sponge_indeterminate * running_evaluation_sponge.clone()
1094            - challenge(ChallengeId::HashCIWeight) * ci_next.clone();
1095        let running_evaluation_sponge_has_accumulated_next_row =
1096            running_evaluation_sponge_has_accumulated_ci.clone() - compressed_row_next;
1097        let if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates =
1098            Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1099                * (ci_next.clone() - opcode_hash)
1100                * running_evaluation_sponge_has_accumulated_next_row;
1101
1102        let running_evaluation_sponge_remains =
1103            running_evaluation_sponge_next - running_evaluation_sponge;
1104        let if_round_no_next_is_not_0_then_running_evaluation_sponge_remains =
1105            round_number_next.clone() * running_evaluation_sponge_remains.clone();
1106        let if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains = (ci_next.clone()
1107            - opcode_sponge_init)
1108            * (ci_next.clone() - opcode_sponge_absorb)
1109            * (ci_next - opcode_sponge_squeeze)
1110            * running_evaluation_sponge_remains;
1111        let running_evaluation_sponge_is_updated_correctly =
1112            if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates
1113                + if_round_no_next_is_not_0_then_running_evaluation_sponge_remains
1114                + if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains;
1115
1116        // program attestation: absorb RATE instructions if in the right mode on
1117        // the right row
1118        let compressed_chunk = state_next[..tip5::RATE]
1119            .iter()
1120            .fold(running_evaluation_initial, |acc, rate_element| {
1121                acc * prepare_chunk_indeterminate.clone() + rate_element.clone()
1122            });
1123        let receive_chunk_running_evaluation_absorbs_chunk_of_instructions =
1124            running_evaluation_receive_chunk_next.clone()
1125                - receive_chunk_indeterminate * running_evaluation_receive_chunk.clone()
1126                - compressed_chunk;
1127        let receive_chunk_running_evaluation_remains =
1128            running_evaluation_receive_chunk_next - running_evaluation_receive_chunk;
1129        let receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0 =
1130            Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1131                * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
1132                * receive_chunk_running_evaluation_absorbs_chunk_of_instructions
1133                + round_number_next * receive_chunk_running_evaluation_remains.clone()
1134                + Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
1135                    * receive_chunk_running_evaluation_remains;
1136
1137        let constraints = vec![
1138            round_number_is_0_through_4_or_round_number_next_is_0,
1139            next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one,
1140            receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0,
1141            if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest,
1142            if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init,
1143            if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change,
1144            if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change,
1145            if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad,
1146            if_mode_is_hash_then_mode_next_is_hash_or_pad,
1147            if_mode_is_pad_then_mode_next_is_pad,
1148            capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing,
1149            if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change,
1150            running_evaluation_hash_input_is_updated_correctly,
1151            running_evaluation_hash_digest_is_updated_correctly,
1152            running_evaluation_sponge_is_updated_correctly,
1153            Self::cascade_log_derivative_update_circuit(
1154                circuit_builder,
1155                Self::MainColumn::State0HighestLkIn,
1156                Self::MainColumn::State0HighestLkOut,
1157                Self::AuxColumn::CascadeState0HighestClientLogDerivative,
1158            ),
1159            Self::cascade_log_derivative_update_circuit(
1160                circuit_builder,
1161                Self::MainColumn::State0MidHighLkIn,
1162                Self::MainColumn::State0MidHighLkOut,
1163                Self::AuxColumn::CascadeState0MidHighClientLogDerivative,
1164            ),
1165            Self::cascade_log_derivative_update_circuit(
1166                circuit_builder,
1167                Self::MainColumn::State0MidLowLkIn,
1168                Self::MainColumn::State0MidLowLkOut,
1169                Self::AuxColumn::CascadeState0MidLowClientLogDerivative,
1170            ),
1171            Self::cascade_log_derivative_update_circuit(
1172                circuit_builder,
1173                Self::MainColumn::State0LowestLkIn,
1174                Self::MainColumn::State0LowestLkOut,
1175                Self::AuxColumn::CascadeState0LowestClientLogDerivative,
1176            ),
1177            Self::cascade_log_derivative_update_circuit(
1178                circuit_builder,
1179                Self::MainColumn::State1HighestLkIn,
1180                Self::MainColumn::State1HighestLkOut,
1181                Self::AuxColumn::CascadeState1HighestClientLogDerivative,
1182            ),
1183            Self::cascade_log_derivative_update_circuit(
1184                circuit_builder,
1185                Self::MainColumn::State1MidHighLkIn,
1186                Self::MainColumn::State1MidHighLkOut,
1187                Self::AuxColumn::CascadeState1MidHighClientLogDerivative,
1188            ),
1189            Self::cascade_log_derivative_update_circuit(
1190                circuit_builder,
1191                Self::MainColumn::State1MidLowLkIn,
1192                Self::MainColumn::State1MidLowLkOut,
1193                Self::AuxColumn::CascadeState1MidLowClientLogDerivative,
1194            ),
1195            Self::cascade_log_derivative_update_circuit(
1196                circuit_builder,
1197                Self::MainColumn::State1LowestLkIn,
1198                Self::MainColumn::State1LowestLkOut,
1199                Self::AuxColumn::CascadeState1LowestClientLogDerivative,
1200            ),
1201            Self::cascade_log_derivative_update_circuit(
1202                circuit_builder,
1203                Self::MainColumn::State2HighestLkIn,
1204                Self::MainColumn::State2HighestLkOut,
1205                Self::AuxColumn::CascadeState2HighestClientLogDerivative,
1206            ),
1207            Self::cascade_log_derivative_update_circuit(
1208                circuit_builder,
1209                Self::MainColumn::State2MidHighLkIn,
1210                Self::MainColumn::State2MidHighLkOut,
1211                Self::AuxColumn::CascadeState2MidHighClientLogDerivative,
1212            ),
1213            Self::cascade_log_derivative_update_circuit(
1214                circuit_builder,
1215                Self::MainColumn::State2MidLowLkIn,
1216                Self::MainColumn::State2MidLowLkOut,
1217                Self::AuxColumn::CascadeState2MidLowClientLogDerivative,
1218            ),
1219            Self::cascade_log_derivative_update_circuit(
1220                circuit_builder,
1221                Self::MainColumn::State2LowestLkIn,
1222                Self::MainColumn::State2LowestLkOut,
1223                Self::AuxColumn::CascadeState2LowestClientLogDerivative,
1224            ),
1225            Self::cascade_log_derivative_update_circuit(
1226                circuit_builder,
1227                Self::MainColumn::State3HighestLkIn,
1228                Self::MainColumn::State3HighestLkOut,
1229                Self::AuxColumn::CascadeState3HighestClientLogDerivative,
1230            ),
1231            Self::cascade_log_derivative_update_circuit(
1232                circuit_builder,
1233                Self::MainColumn::State3MidHighLkIn,
1234                Self::MainColumn::State3MidHighLkOut,
1235                Self::AuxColumn::CascadeState3MidHighClientLogDerivative,
1236            ),
1237            Self::cascade_log_derivative_update_circuit(
1238                circuit_builder,
1239                Self::MainColumn::State3MidLowLkIn,
1240                Self::MainColumn::State3MidLowLkOut,
1241                Self::AuxColumn::CascadeState3MidLowClientLogDerivative,
1242            ),
1243            Self::cascade_log_derivative_update_circuit(
1244                circuit_builder,
1245                Self::MainColumn::State3LowestLkIn,
1246                Self::MainColumn::State3LowestLkOut,
1247                Self::AuxColumn::CascadeState3LowestClientLogDerivative,
1248            ),
1249        ];
1250
1251        [
1252            constraints,
1253            hash_function_round_correctly_performs_update.to_vec(),
1254        ]
1255        .concat()
1256    }
1257
1258    fn terminal_constraints(
1259        circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
1260    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
1261        let challenge = |c| circuit_builder.challenge(c);
1262        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
1263        let constant = |c: u64| circuit_builder.b_constant(c);
1264        let main_row = |column_idx: Self::MainColumn| {
1265            circuit_builder.input(Main(column_idx.master_main_index()))
1266        };
1267
1268        let mode = main_row(Self::MainColumn::Mode);
1269        let round_number = main_row(Self::MainColumn::RoundNumber);
1270
1271        let compress_program_digest_indeterminate =
1272            challenge(ChallengeId::CompressProgramDigestIndeterminate);
1273        let expected_program_digest = challenge(ChallengeId::CompressedProgramDigest);
1274
1275        let max_round_number = constant(NUM_ROUNDS as u64);
1276
1277        let [state_0, state_1, state_2, state_3] =
1278            Self::re_compose_states_0_through_3_before_lookup(
1279                circuit_builder,
1280                Self::indicate_column_index_in_main_row,
1281            );
1282        let state_4 = main_row(Self::MainColumn::State4);
1283        let program_digest = [state_0, state_1, state_2, state_3, state_4];
1284        let compressed_digest = program_digest.into_iter().fold(
1285            circuit_builder.x_constant(EvalArg::default_initial()),
1286            |acc, digest_element| {
1287                acc * compress_program_digest_indeterminate.clone() + digest_element
1288            },
1289        );
1290        let if_mode_is_program_hashing_then_current_digest_is_expected_program_digest =
1291            Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
1292                * (compressed_digest - expected_program_digest);
1293
1294        let if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number =
1295            Self::select_mode(circuit_builder, &mode, HashTableMode::Pad)
1296                * (main_row(Self::MainColumn::CI) - opcode(Instruction::SpongeInit))
1297                * (round_number - max_round_number);
1298
1299        vec![
1300            if_mode_is_program_hashing_then_current_digest_is_expected_program_digest,
1301            if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number,
1302        ]
1303    }
1304}
1305
/// The current “mode” of the Hash Table. The Hash Table can be in one of four
/// distinct modes:
///
/// 1. Hashing the [`Program`][program]. This is part of program attestation.
/// 2. Processing all Sponge instructions, _i.e._, `sponge_init`,
///    `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`.
/// 3. Processing the `hash` instruction.
/// 4. Padding mode.
///
/// Changing the mode is only possible when the current
/// [`RoundNumber`][round_no] is [`NUM_ROUNDS`]. The mode evolves as
/// [`ProgramHashing`][prog_hash] → [`Sponge`][sponge] → [`Hash`][hash] →
/// [`Pad`][pad]. Once mode [`Pad`][pad] is reached, it is not possible to
/// change the mode anymore. Skipping any or all of the modes
/// [`Sponge`][sponge], [`Hash`][hash], or [`Pad`][pad] is possible in
/// principle:
/// - if no Sponge instructions are executed, mode [`Sponge`][sponge] will be
///   skipped,
/// - if no `hash` instruction is executed, mode [`Hash`][hash] will be skipped,
///   and
/// - if the Hash Table does not require any padding, mode [`Pad`][pad] will be
///   skipped.
///
/// It is not possible to skip mode [`ProgramHashing`][prog_hash]:
/// the [`Program`][program] is always hashed.
/// The empty program is not valid since any valid [`Program`][program] must
/// execute instruction `halt`.
///
/// # Numeric encoding
///
/// The conversions into [`u32`], [`u64`], and [`BFieldElement`] all agree
/// and map [`Pad`][pad] ↦ 0, [`ProgramHashing`][prog_hash] ↦ 1,
/// [`Sponge`][sponge] ↦ 2, and [`Hash`][hash] ↦ 3. Note that this encoding
/// does _not_ follow the declaration order of the variants.
///
/// [round_no]: crate::table_column::HashMainColumn::RoundNumber
/// [program]: isa::program::Program
/// [prog_hash]: HashTableMode::ProgramHashing
/// [sponge]: HashTableMode::Sponge
/// [hash]: type@HashTableMode::Hash
/// [pad]: HashTableMode::Pad
#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)]
pub enum HashTableMode {
    /// The mode in which the [`Program`][program] is hashed. This is part of
    /// program attestation.
    ///
    /// Numerically encoded as 1.
    ///
    /// [program]: isa::program::Program
    ProgramHashing,

    /// The mode in which Sponge instructions, _i.e._, `sponge_init`,
    /// `sponge_absorb`, `sponge_absorb_mem`, and `sponge_squeeze`, are
    /// processed.
    ///
    /// Numerically encoded as 2.
    Sponge,

    /// The mode in which the `hash` instruction is processed.
    ///
    /// Numerically encoded as 3.
    Hash,

    /// Indicator for padding rows.
    ///
    /// Numerically encoded as 0.
    Pad,
}
1359
1360impl From<HashTableMode> for u32 {
1361    fn from(mode: HashTableMode) -> Self {
1362        match mode {
1363            HashTableMode::ProgramHashing => 1,
1364            HashTableMode::Sponge => 2,
1365            HashTableMode::Hash => 3,
1366            HashTableMode::Pad => 0,
1367        }
1368    }
1369}
1370
1371impl From<HashTableMode> for u64 {
1372    fn from(mode: HashTableMode) -> Self {
1373        let discriminant: u32 = mode.into();
1374        discriminant.into()
1375    }
1376}
1377
1378impl From<HashTableMode> for BFieldElement {
1379    fn from(mode: HashTableMode) -> Self {
1380        let discriminant: u32 = mode.into();
1381        discriminant.into()
1382    }
1383}