1use constraint_circuit::ConstraintCircuitBuilder;
2use constraint_circuit::ConstraintCircuitMonad;
3use constraint_circuit::DualRowIndicator;
4use constraint_circuit::DualRowIndicator::CurrentAux;
5use constraint_circuit::DualRowIndicator::CurrentMain;
6use constraint_circuit::DualRowIndicator::NextAux;
7use constraint_circuit::DualRowIndicator::NextMain;
8use constraint_circuit::InputIndicator;
9use constraint_circuit::SingleRowIndicator;
10use constraint_circuit::SingleRowIndicator::Aux;
11use constraint_circuit::SingleRowIndicator::Main;
12use isa::instruction::Instruction;
13use itertools::Itertools;
14use strum::Display;
15use strum::EnumCount;
16use strum::EnumIter;
17use strum::IntoEnumIterator;
18use twenty_first::prelude::tip5::NUM_ROUNDS;
19use twenty_first::prelude::*;
20
21use crate::AIR;
22use crate::challenge_id::ChallengeId;
23use crate::cross_table_argument::CrossTableArg;
24use crate::cross_table_argument::EvalArg;
25use crate::cross_table_argument::LookupArg;
26use crate::table_column::MasterAuxColumn;
27use crate::table_column::MasterMainColumn;
28
29pub const MONTGOMERY_MODULUS: BFieldElement =
30 BFieldElement::new(((1_u128 << 64) % BFieldElement::P as u128) as u64);
31
32const POWER_MAP_EXPONENT: u64 = 7;
33const NUM_ROUND_CONSTANTS: usize = tip5::STATE_SIZE;
34
35pub const PERMUTATION_TRACE_LENGTH: usize = NUM_ROUNDS + 1;
36
37pub type PermutationTrace = [[BFieldElement; tip5::STATE_SIZE]; PERMUTATION_TRACE_LENGTH];
38
39#[derive(Debug, Copy, Clone, Eq, PartialEq)]
40pub struct HashTable;
41
42impl crate::private::Seal for HashTable {}
43
44type MainColumn = <HashTable as AIR>::MainColumn;
45type AuxColumn = <HashTable as AIR>::AuxColumn;
46
impl HashTable {
    /// The entry in row `row_idx` and column `col_idx` of Tip5's MDS matrix.
    ///
    /// The entry only depends on `(row_idx - col_idx) mod STATE_SIZE`, so the matrix is
    /// circulant and completely defined by its first column,
    /// `tip5::MDS_MATRIX_FIRST_COLUMN`.
    ///
    /// # Panics
    ///
    /// Panics if `row_idx` or `col_idx` is not smaller than `tip5::STATE_SIZE`.
    const fn mds_matrix_entry(row_idx: usize, col_idx: usize) -> BFieldElement {
        assert!(row_idx < tip5::STATE_SIZE);
        assert!(col_idx < tip5::STATE_SIZE);
        // Adding `STATE_SIZE` before subtracting keeps the `usize` arithmetic from
        // underflowing when `col_idx > row_idx`.
        let index_in_matrix_defining_column =
            (tip5::STATE_SIZE + row_idx - col_idx) % tip5::STATE_SIZE;
        let mds_matrix_entry = tip5::MDS_MATRIX_FIRST_COLUMN[index_in_matrix_defining_column];
        BFieldElement::new(mds_matrix_entry as u64)
    }

    /// The `NUM_ROUND_CONSTANTS` round constants of Tip5 round `r`.
    ///
    /// If `r` is not a valid round index, i.e., if `r >= NUM_ROUNDS`, all returned
    /// constants are 0.
    pub fn tip5_round_constants_by_round_number(r: usize) -> [BFieldElement; NUM_ROUND_CONSTANTS] {
        if r >= NUM_ROUNDS {
            return bfe_array![0; NUM_ROUND_CONSTANTS];
        }

        // The constants of all rounds are stored back to back in `tip5::ROUND_CONSTANTS`,
        // one chunk of `NUM_ROUND_CONSTANTS` elements per round.
        let range_start = NUM_ROUND_CONSTANTS * r;
        let range_end = NUM_ROUND_CONSTANTS * (r + 1);
        // The slice has length `NUM_ROUND_CONSTANTS` by construction, so the conversion
        // into an array cannot fail.
        tip5::ROUND_CONSTANTS[range_start..range_end]
            .try_into()
            .unwrap()
    }

    /// Re-compose one state element from its four 16-bit limbs.
    ///
    /// The result is
    /// `(highest * 2^48 + mid_high * 2^32 + mid_low * 2^16 + lowest) * R^(-1)`,
    /// where `R` is `MONTGOMERY_MODULUS`. Presumably the limbs decompose the Montgomery
    /// representation of the element, which the multiplication with `R^(-1)` undoes.
    /// NOTE(review): confirm against the trace generation.
    ///
    /// This is how states 0 through 3, the elements passing through the split-and-lookup
    /// part of the S-box, are accessed. States 4 through 15 have dedicated columns; see
    /// `Self::state_column_by_index`.
    fn re_compose_16_bit_limbs<II: InputIndicator>(
        circuit_builder: &ConstraintCircuitBuilder<II>,
        highest: ConstraintCircuitMonad<II>,
        mid_high: ConstraintCircuitMonad<II>,
        mid_low: ConstraintCircuitMonad<II>,
        lowest: ConstraintCircuitMonad<II>,
    ) -> ConstraintCircuitMonad<II> {
        let constant = |c: u64| circuit_builder.b_constant(c);
        let montgomery_modulus_inv = circuit_builder.b_constant(MONTGOMERY_MODULUS.inverse());

        let sum_of_shifted_limbs = highest * constant(1 << 48)
            + mid_high * constant(1 << 32)
            + mid_low * constant(1 << 16)
            + lowest;
        sum_of_shifted_limbs * montgomery_modulus_inv
    }

    /// A circuit that, for round numbers in `0..=NUM_ROUNDS`, evaluates to zero if and
    /// only if `round_number_circuit_node` is *not* equal to `round_number_to_deselect`.
    ///
    /// The circuit is the product of `(round_number - r)` over all `r` in `0..=NUM_ROUNDS`
    /// other than `round_number_to_deselect`. Multiplying a constraint with it therefore
    /// switches the constraint off in all rows with a different round number.
    ///
    /// # Panics
    ///
    /// Panics if `round_number_to_deselect` is larger than `NUM_ROUNDS`.
    fn round_number_deselector<II: InputIndicator>(
        circuit_builder: &ConstraintCircuitBuilder<II>,
        round_number_circuit_node: &ConstraintCircuitMonad<II>,
        round_number_to_deselect: usize,
    ) -> ConstraintCircuitMonad<II> {
        assert!(
            round_number_to_deselect <= NUM_ROUNDS,
            "Round number must be in [0, {NUM_ROUNDS}] but got {round_number_to_deselect}."
        );
        let constant = |c: u64| circuit_builder.b_constant(c);

        // The factor for `r = 0` is `round_number - 0`, i.e., the node itself. If 0 is the
        // round number to deselect, that factor is omitted by starting with 1 instead.
        let first_factor = match round_number_to_deselect {
            0 => constant(1),
            _ => round_number_circuit_node.clone(),
        };
        (1..=NUM_ROUNDS)
            .filter(|&r| r != round_number_to_deselect)
            .map(|r| round_number_circuit_node.clone() - constant(r as u64))
            .fold(first_factor, |a, b| a * b)
    }

    /// A circuit that evaluates to zero if and only if `mode_circuit_node` equals
    /// `mode_to_select`. It is the difference `mode - mode_to_select`.
    fn select_mode<II: InputIndicator>(
        circuit_builder: &ConstraintCircuitBuilder<II>,
        mode_circuit_node: &ConstraintCircuitMonad<II>,
        mode_to_select: HashTableMode,
    ) -> ConstraintCircuitMonad<II> {
        mode_circuit_node.clone() - circuit_builder.b_constant(mode_to_select)
    }

    /// A circuit that vanishes whenever `mode_circuit_node` is any `HashTableMode` other
    /// than `mode_to_deselect`.
    ///
    /// It is the product of `(mode - m)` over all modes `m` yielded by
    /// `HashTableMode::iter()` except `mode_to_deselect`.
    fn mode_deselector<II: InputIndicator>(
        circuit_builder: &ConstraintCircuitBuilder<II>,
        mode_circuit_node: &ConstraintCircuitMonad<II>,
        mode_to_deselect: HashTableMode,
    ) -> ConstraintCircuitMonad<II> {
        let constant = |c: u64| circuit_builder.b_constant(c);
        HashTableMode::iter()
            .filter(|&mode| mode != mode_to_deselect)
            .map(|mode| mode_circuit_node.clone() - constant(mode.into()))
            .fold(constant(1), |accumulator, factor| accumulator * factor)
    }

    /// A circuit that vanishes whenever `current_instruction_node` holds the opcode of
    /// one of `hash`, `sponge_init`, `sponge_absorb`, `sponge_squeeze` other than
    /// `instruction_to_deselect`.
    ///
    /// It is the product of `(ci - opcode(i))` over the other three of these instructions.
    ///
    /// # Panics
    ///
    /// Panics if `instruction_to_deselect` is not one of the four instructions above.
    fn instruction_deselector<II: InputIndicator>(
        circuit_builder: &ConstraintCircuitBuilder<II>,
        current_instruction_node: &ConstraintCircuitMonad<II>,
        instruction_to_deselect: Instruction,
    ) -> ConstraintCircuitMonad<II> {
        let constant = |c: u64| circuit_builder.b_constant(c);
        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
        let relevant_instructions = [
            Instruction::Hash,
            Instruction::SpongeInit,
            Instruction::SpongeAbsorb,
            Instruction::SpongeSqueeze,
        ];
        assert!(relevant_instructions.contains(&instruction_to_deselect));

        relevant_instructions
            .iter()
            .filter(|&instruction| instruction != &instruction_to_deselect)
            .map(|&instruction| current_instruction_node.clone() - opcode(instruction))
            .fold(constant(1), |accumulator, factor| accumulator * factor)
    }

    /// The main column holding the round constant with the given index.
    ///
    /// Valid indices are 0 through 15, mapping to `Constant0` through `Constant15`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is larger than 15.
    pub fn round_constant_column_by_index(index: usize) -> MainColumn {
        match index {
            0 => MainColumn::Constant0,
            1 => MainColumn::Constant1,
            2 => MainColumn::Constant2,
            3 => MainColumn::Constant3,
            4 => MainColumn::Constant4,
            5 => MainColumn::Constant5,
            6 => MainColumn::Constant6,
            7 => MainColumn::Constant7,
            8 => MainColumn::Constant8,
            9 => MainColumn::Constant9,
            10 => MainColumn::Constant10,
            11 => MainColumn::Constant11,
            12 => MainColumn::Constant12,
            13 => MainColumn::Constant13,
            14 => MainColumn::Constant14,
            15 => MainColumn::Constant15,
            _ => panic!("invalid constant column index"),
        }
    }

    /// The main column holding the state element with the given index.
    ///
    /// Valid indices are 4 through 15, mapping to `State4` through `State15`. States 0
    /// through 3 have no single column. They are spread over four 16-bit limb columns
    /// each; see `Self::re_compose_states_0_through_3_before_lookup`.
    ///
    /// # Panics
    ///
    /// Panics for any index outside of 4 through 15.
    fn state_column_by_index(index: usize) -> MainColumn {
        match index {
            4 => MainColumn::State4,
            5 => MainColumn::State5,
            6 => MainColumn::State6,
            7 => MainColumn::State7,
            8 => MainColumn::State8,
            9 => MainColumn::State9,
            10 => MainColumn::State10,
            11 => MainColumn::State11,
            12 => MainColumn::State12,
            13 => MainColumn::State13,
            14 => MainColumn::State14,
            15 => MainColumn::State15,
            _ => panic!("invalid state column index"),
        }
    }

    /// The single-row input indicator for `column` in the main row.
    fn indicate_column_index_in_main_row(column: MainColumn) -> SingleRowIndicator {
        Main(column.master_main_index())
    }

    /// The dual-row input indicator for `column` in the current main row.
    fn indicate_column_index_in_current_main_row(column: MainColumn) -> DualRowIndicator {
        CurrentMain(column.master_main_index())
    }

    /// The dual-row input indicator for `column` in the next main row.
    fn indicate_column_index_in_next_main_row(column: MainColumn) -> DualRowIndicator {
        NextMain(column.master_main_index())
    }

    /// Re-compose state elements 0 through 3 from their look-in limbs, i.e., their value
    /// *before* the lookup of the split-and-lookup S-box.
    ///
    /// `main_row_to_input_indicator` decides which row the limbs are read from. For
    /// example, `Self::indicate_column_index_in_main_row` is used for single-row
    /// constraints and `Self::indicate_column_index_in_next_main_row` for the next row
    /// of dual-row constraints.
    fn re_compose_states_0_through_3_before_lookup<II: InputIndicator>(
        circuit_builder: &ConstraintCircuitBuilder<II>,
        main_row_to_input_indicator: fn(MainColumn) -> II,
    ) -> [ConstraintCircuitMonad<II>; 4] {
        let input = |input_indicator: II| circuit_builder.input(input_indicator);
        let state_0 = Self::re_compose_16_bit_limbs(
            circuit_builder,
            input(main_row_to_input_indicator(MainColumn::State0HighestLkIn)),
            input(main_row_to_input_indicator(MainColumn::State0MidHighLkIn)),
            input(main_row_to_input_indicator(MainColumn::State0MidLowLkIn)),
            input(main_row_to_input_indicator(MainColumn::State0LowestLkIn)),
        );
        let state_1 = Self::re_compose_16_bit_limbs(
            circuit_builder,
            input(main_row_to_input_indicator(MainColumn::State1HighestLkIn)),
            input(main_row_to_input_indicator(MainColumn::State1MidHighLkIn)),
            input(main_row_to_input_indicator(MainColumn::State1MidLowLkIn)),
            input(main_row_to_input_indicator(MainColumn::State1LowestLkIn)),
        );
        let state_2 = Self::re_compose_16_bit_limbs(
            circuit_builder,
            input(main_row_to_input_indicator(MainColumn::State2HighestLkIn)),
            input(main_row_to_input_indicator(MainColumn::State2MidHighLkIn)),
            input(main_row_to_input_indicator(MainColumn::State2MidLowLkIn)),
            input(main_row_to_input_indicator(MainColumn::State2LowestLkIn)),
        );
        let state_3 = Self::re_compose_16_bit_limbs(
            circuit_builder,
            input(main_row_to_input_indicator(MainColumn::State3HighestLkIn)),
            input(main_row_to_input_indicator(MainColumn::State3MidHighLkIn)),
            input(main_row_to_input_indicator(MainColumn::State3MidLowLkIn)),
            input(main_row_to_input_indicator(MainColumn::State3LowestLkIn)),
        );
        [state_0, state_1, state_2, state_3]
    }

    /// Transition constraints enforcing one round of the Tip5 permutation between the
    /// current and the next row.
    ///
    /// Returns a pair:
    /// 1. the next row's state, where states 0 through 3 are re-composed from the next
    ///    row's look-in limbs and states 4 through 15 are read directly, and
    /// 2. one constraint per state element, requiring that one round applied to the
    ///    current row's state yields the next row's state.
    ///
    /// A round is modeled as:
    /// - the S-box: split-and-lookup for states 0 through 3, taken from the current row's
    ///   look-out limbs, and the power map `x^POWER_MAP_EXPONENT` for states 4 through 15;
    /// - multiplication with the MDS matrix;
    /// - addition of the round constants stored in the current row.
    ///
    /// Each constraint is multiplied by the next row's round number, so it vanishes
    /// wherever that round number is 0.
    fn tip5_constraints_as_circuits(
        circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
    ) -> (
        [ConstraintCircuitMonad<DualRowIndicator>; tip5::STATE_SIZE],
        [ConstraintCircuitMonad<DualRowIndicator>; tip5::STATE_SIZE],
    ) {
        let constant = |c: u64| circuit_builder.b_constant(c);
        // Unlike `constant`, not restricted to `u64`; used for the `BFieldElement` entries
        // of the MDS matrix.
        let b_constant = |c| circuit_builder.b_constant(c);
        let current_main_row = |column_idx: MainColumn| {
            circuit_builder.input(CurrentMain(column_idx.master_main_index()))
        };
        let next_main_row = |column_idx: MainColumn| {
            circuit_builder.input(NextMain(column_idx.master_main_index()))
        };

        // S-box, split-and-lookup part: states 0 through 3 *after* the lookup, re-composed
        // from the current row's look-out limbs.
        let state_0_after_lookup = Self::re_compose_16_bit_limbs(
            circuit_builder,
            current_main_row(MainColumn::State0HighestLkOut),
            current_main_row(MainColumn::State0MidHighLkOut),
            current_main_row(MainColumn::State0MidLowLkOut),
            current_main_row(MainColumn::State0LowestLkOut),
        );
        let state_1_after_lookup = Self::re_compose_16_bit_limbs(
            circuit_builder,
            current_main_row(MainColumn::State1HighestLkOut),
            current_main_row(MainColumn::State1MidHighLkOut),
            current_main_row(MainColumn::State1MidLowLkOut),
            current_main_row(MainColumn::State1LowestLkOut),
        );
        let state_2_after_lookup = Self::re_compose_16_bit_limbs(
            circuit_builder,
            current_main_row(MainColumn::State2HighestLkOut),
            current_main_row(MainColumn::State2MidHighLkOut),
            current_main_row(MainColumn::State2MidLowLkOut),
            current_main_row(MainColumn::State2LowestLkOut),
        );
        let state_3_after_lookup = Self::re_compose_16_bit_limbs(
            circuit_builder,
            current_main_row(MainColumn::State3HighestLkOut),
            current_main_row(MainColumn::State3MidHighLkOut),
            current_main_row(MainColumn::State3MidLowLkOut),
            current_main_row(MainColumn::State3LowestLkOut),
        );

        // S-box, power-map part: the input is states 4 through 15 of the current row.
        let state_part_before_power_map: [_; tip5::STATE_SIZE - tip5::NUM_SPLIT_AND_LOOKUP] = [
            MainColumn::State4,
            MainColumn::State5,
            MainColumn::State6,
            MainColumn::State7,
            MainColumn::State8,
            MainColumn::State9,
            MainColumn::State10,
            MainColumn::State11,
            MainColumn::State12,
            MainColumn::State13,
            MainColumn::State14,
            MainColumn::State15,
        ]
        .map(current_main_row);

        // Raise every element to the power `POWER_MAP_EXPONENT`: start with the element
        // and multiply by it `POWER_MAP_EXPONENT - 1` more times.
        let state_part_after_power_map = {
            let mut exponentiation_accumulator = state_part_before_power_map.clone();
            for _ in 1..POWER_MAP_EXPONENT {
                for (i, state) in exponentiation_accumulator.iter_mut().enumerate() {
                    *state = state.clone() * state_part_before_power_map[i].clone();
                }
            }
            exponentiation_accumulator
        };

        let state_after_s_box_application = [
            state_0_after_lookup,
            state_1_after_lookup,
            state_2_after_lookup,
            state_3_after_lookup,
            state_part_after_power_map[0].clone(),
            state_part_after_power_map[1].clone(),
            state_part_after_power_map[2].clone(),
            state_part_after_power_map[3].clone(),
            state_part_after_power_map[4].clone(),
            state_part_after_power_map[5].clone(),
            state_part_after_power_map[6].clone(),
            state_part_after_power_map[7].clone(),
            state_part_after_power_map[8].clone(),
            state_part_after_power_map[9].clone(),
            state_part_after_power_map[10].clone(),
            state_part_after_power_map[11].clone(),
        ];

        // Matrix-vector product of the MDS matrix with the S-box output.
        let mut state_after_matrix_multiplication = vec![constant(0); tip5::STATE_SIZE];
        for (row_idx, acc) in state_after_matrix_multiplication.iter_mut().enumerate() {
            for (col_idx, state) in state_after_s_box_application.iter().enumerate() {
                let matrix_entry = b_constant(Self::mds_matrix_entry(row_idx, col_idx));
                *acc = acc.clone() + matrix_entry * state.clone();
            }
        }

        // The round constants are read from the current row's constant columns. Their
        // values are pinned down by the round-constant consistency constraints.
        let round_constants: [_; tip5::STATE_SIZE] = [
            MainColumn::Constant0,
            MainColumn::Constant1,
            MainColumn::Constant2,
            MainColumn::Constant3,
            MainColumn::Constant4,
            MainColumn::Constant5,
            MainColumn::Constant6,
            MainColumn::Constant7,
            MainColumn::Constant8,
            MainColumn::Constant9,
            MainColumn::Constant10,
            MainColumn::Constant11,
            MainColumn::Constant12,
            MainColumn::Constant13,
            MainColumn::Constant14,
            MainColumn::Constant15,
        ]
        .map(current_main_row);

        let state_after_round_constant_addition = state_after_matrix_multiplication
            .into_iter()
            .zip_eq(round_constants)
            .map(|(st, rndc)| st + rndc)
            .collect_vec();

        // The next row's state. States 0 through 3 are taken *before* the lookup, i.e.,
        // re-composed from the next row's look-in limbs.
        let [state_0_next, state_1_next, state_2_next, state_3_next] =
            Self::re_compose_states_0_through_3_before_lookup(
                circuit_builder,
                Self::indicate_column_index_in_next_main_row,
            );
        let state_next = [
            state_0_next,
            state_1_next,
            state_2_next,
            state_3_next,
            next_main_row(MainColumn::State4),
            next_main_row(MainColumn::State5),
            next_main_row(MainColumn::State6),
            next_main_row(MainColumn::State7),
            next_main_row(MainColumn::State8),
            next_main_row(MainColumn::State9),
            next_main_row(MainColumn::State10),
            next_main_row(MainColumn::State11),
            next_main_row(MainColumn::State12),
            next_main_row(MainColumn::State13),
            next_main_row(MainColumn::State14),
            next_main_row(MainColumn::State15),
        ];

        // Multiplying by the next round number switches each constraint off wherever the
        // next row's round number is 0.
        let round_number_next = next_main_row(MainColumn::RoundNumber);
        // `zip_eq` yields exactly `tip5::STATE_SIZE` items, so the array conversion
        // cannot fail.
        let hash_function_round_correctly_performs_update = state_after_round_constant_addition
            .into_iter()
            .zip_eq(state_next.clone())
            .map(|(state_element, state_element_next)| {
                round_number_next.clone() * (state_element - state_element_next)
            })
            .collect_vec()
            .try_into()
            .unwrap();

        (state_next, hash_function_round_correctly_performs_update)
    }

    /// Transition constraint for the running log derivative `ld` (stored in
    /// `cascade_log_derivative_column`) of one lookup argument with the Cascade Table.
    /// The looked-up pair comes from `look_in_column` and `look_out_column` of the next
    /// row.
    ///
    /// With `compressed_row = look_in_weight * look_in' + look_out_weight * look_out'`,
    /// the returned circuit is
    ///
    /// ```text
    ///   (mode' - Pad) * (round_number' - NUM_ROUNDS) * (ci' - opcode(sponge_init)) * update
    /// + round_number_deselector(round_number', NUM_ROUNDS) * remains
    /// + instruction_deselector(ci', sponge_init) * remains
    /// ```
    ///
    /// where `update = (ld' - ld) * (cascade_indeterminate - compressed_row) - 1` and
    /// `remains = ld' - ld`. The first summand vanishes if the next row is in mode `Pad`,
    /// has round number `NUM_ROUNDS`, or has current instruction `sponge_init`.
    fn cascade_log_derivative_update_circuit(
        circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
        look_in_column: MainColumn,
        look_out_column: MainColumn,
        cascade_log_derivative_column: AuxColumn,
    ) -> ConstraintCircuitMonad<DualRowIndicator> {
        let challenge = |c| circuit_builder.challenge(c);
        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
        let constant = |c: u32| circuit_builder.b_constant(c);
        let next_main_row = |column_idx: MainColumn| {
            circuit_builder.input(NextMain(column_idx.master_main_index()))
        };
        let current_aux_row = |column_idx: AuxColumn| {
            circuit_builder.input(CurrentAux(column_idx.master_aux_index()))
        };
        let next_aux_row =
            |column_idx: AuxColumn| circuit_builder.input(NextAux(column_idx.master_aux_index()));

        let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate);
        let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight);
        let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight);

        let ci_next = next_main_row(MainColumn::CI);
        let mode_next = next_main_row(MainColumn::Mode);
        let round_number_next = next_main_row(MainColumn::RoundNumber);
        let cascade_log_derivative = current_aux_row(cascade_log_derivative_column);
        let cascade_log_derivative_next = next_aux_row(cascade_log_derivative_column);

        // The look-in/look-out pair of the next row, compressed into a single element.
        let compressed_row = look_in_weight * next_main_row(look_in_column)
            + look_out_weight * next_main_row(look_out_column);

        let cascade_log_derivative_remains =
            cascade_log_derivative_next.clone() - cascade_log_derivative.clone();
        // Requires `ld' = ld + 1 / (cascade_indeterminate - compressed_row)`.
        let cascade_log_derivative_updates = (cascade_log_derivative_next - cascade_log_derivative)
            * (cascade_indeterminate - compressed_row)
            - constant(1);

        let next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init =
            Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
                * (round_number_next.clone() - constant(NUM_ROUNDS as u32))
                * (ci_next.clone() - opcode(Instruction::SpongeInit));
        let round_number_next_is_not_num_rounds =
            Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS);
        let current_instruction_next_is_not_sponge_init =
            Self::instruction_deselector(circuit_builder, &ci_next, Instruction::SpongeInit);

        next_row_is_padding_row_or_round_number_next_is_max_or_ci_next_is_sponge_init
            * cascade_log_derivative_updates
            + round_number_next_is_not_num_rounds * cascade_log_derivative_remains.clone()
            + current_instruction_next_is_not_sponge_init * cascade_log_derivative_remains
    }
}
489
490impl AIR for HashTable {
    /// The main-table columns of the Hash Table.
    type MainColumn = crate::table_column::HashMainColumn;
    /// The auxiliary-table columns of the Hash Table.
    type AuxColumn = crate::table_column::HashAuxColumn;
493
    /// Constraints on the first row of the Hash Table.
    ///
    /// In order, they require that in the first row:
    /// - the mode is `ProgramHashing`;
    /// - the round number is 0;
    /// - the running evaluations for hash input, hash digest, and the Sponge hold the
    ///   evaluation argument's default initial value;
    /// - the running evaluation for receiving program chunks has absorbed the first chunk,
    ///   i.e., the rate part of the first row's state;
    /// - each of the 16 running log derivatives for the lookups with the Cascade Table
    ///   (4 limbs for each of states 0 through 3) is initialized with the corresponding
    ///   look-in/look-out limb pair of the first row.
    fn initial_constraints(
        circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
        let challenge = |c| circuit_builder.challenge(c);
        let constant = |c: u64| circuit_builder.b_constant(c);

        let main_row =
            |column: Self::MainColumn| circuit_builder.input(Main(column.master_main_index()));
        let aux_row =
            |column: Self::AuxColumn| circuit_builder.input(Aux(column.master_aux_index()));

        let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial());
        let lookup_arg_default_initial = circuit_builder.x_constant(LookupArg::default_initial());

        let mode = main_row(Self::MainColumn::Mode);
        let running_evaluation_hash_input = aux_row(Self::AuxColumn::HashInputRunningEvaluation);
        let running_evaluation_hash_digest = aux_row(Self::AuxColumn::HashDigestRunningEvaluation);
        let running_evaluation_sponge = aux_row(Self::AuxColumn::SpongeRunningEvaluation);
        let running_evaluation_receive_chunk =
            aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation);

        let cascade_indeterminate = challenge(ChallengeId::HashCascadeLookupIndeterminate);
        let look_in_weight = challenge(ChallengeId::HashCascadeLookInWeight);
        let look_out_weight = challenge(ChallengeId::HashCascadeLookOutWeight);
        let prepare_chunk_indeterminate =
            challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate);
        let receive_chunk_indeterminate =
            challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate);

        // The first chunk of the program is received correctly: the rate part of the
        // state (states 0 through 3 before the lookup, and states 4 through 9) is
        // compressed by evaluating it as a polynomial at the prepare-chunk indeterminate.
        // The result is then absorbed into the receive-chunk running evaluation.
        let [state_0, state_1, state_2, state_3] =
            Self::re_compose_states_0_through_3_before_lookup(
                circuit_builder,
                Self::indicate_column_index_in_main_row,
            );
        let state_rate_part: [_; tip5::RATE] = [
            state_0,
            state_1,
            state_2,
            state_3,
            main_row(Self::MainColumn::State4),
            main_row(Self::MainColumn::State5),
            main_row(Self::MainColumn::State6),
            main_row(Self::MainColumn::State7),
            main_row(Self::MainColumn::State8),
            main_row(Self::MainColumn::State9),
        ];
        let compressed_chunk = state_rate_part
            .into_iter()
            .fold(running_evaluation_initial.clone(), |acc, state_element| {
                acc * prepare_chunk_indeterminate.clone() + state_element
            });
        let running_evaluation_receive_chunk_is_initialized_correctly =
            running_evaluation_receive_chunk
                - receive_chunk_indeterminate * running_evaluation_initial.clone()
                - compressed_chunk;

        // Each lookup argument with the Cascade Table starts from the default initial
        // value and incorporates the first row's limb pair, i.e.,
        // `ld = default_initial + 1 / (cascade_indeterminate - compressed_row)`.
        let cascade_log_derivative_init_circuit =
            |look_in_column, look_out_column, cascade_log_derivative_column| {
                let look_in = main_row(look_in_column);
                let look_out = main_row(look_out_column);
                let compressed_row =
                    look_in_weight.clone() * look_in + look_out_weight.clone() * look_out;
                let cascade_log_derivative = aux_row(cascade_log_derivative_column);
                (cascade_log_derivative - lookup_arg_default_initial.clone())
                    * (cascade_indeterminate.clone() - compressed_row)
                    - constant(1)
            };

        // Miscellaneous initial constraints.
        let mode_is_program_hashing =
            Self::select_mode(circuit_builder, &mode, HashTableMode::ProgramHashing);
        // The round-number column itself; the constraint requires it to be 0.
        let round_number_is_0 = main_row(Self::MainColumn::RoundNumber);
        let running_evaluation_hash_input_is_default_initial =
            running_evaluation_hash_input - running_evaluation_initial.clone();
        let running_evaluation_hash_digest_is_default_initial =
            running_evaluation_hash_digest - running_evaluation_initial.clone();
        let running_evaluation_sponge_is_default_initial =
            running_evaluation_sponge - running_evaluation_initial;

        vec![
            mode_is_program_hashing,
            round_number_is_0,
            running_evaluation_hash_input_is_default_initial,
            running_evaluation_hash_digest_is_default_initial,
            running_evaluation_sponge_is_default_initial,
            running_evaluation_receive_chunk_is_initialized_correctly,
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State0HighestLkIn,
                Self::MainColumn::State0HighestLkOut,
                Self::AuxColumn::CascadeState0HighestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State0MidHighLkIn,
                Self::MainColumn::State0MidHighLkOut,
                Self::AuxColumn::CascadeState0MidHighClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State0MidLowLkIn,
                Self::MainColumn::State0MidLowLkOut,
                Self::AuxColumn::CascadeState0MidLowClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State0LowestLkIn,
                Self::MainColumn::State0LowestLkOut,
                Self::AuxColumn::CascadeState0LowestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State1HighestLkIn,
                Self::MainColumn::State1HighestLkOut,
                Self::AuxColumn::CascadeState1HighestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State1MidHighLkIn,
                Self::MainColumn::State1MidHighLkOut,
                Self::AuxColumn::CascadeState1MidHighClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State1MidLowLkIn,
                Self::MainColumn::State1MidLowLkOut,
                Self::AuxColumn::CascadeState1MidLowClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State1LowestLkIn,
                Self::MainColumn::State1LowestLkOut,
                Self::AuxColumn::CascadeState1LowestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State2HighestLkIn,
                Self::MainColumn::State2HighestLkOut,
                Self::AuxColumn::CascadeState2HighestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State2MidHighLkIn,
                Self::MainColumn::State2MidHighLkOut,
                Self::AuxColumn::CascadeState2MidHighClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State2MidLowLkIn,
                Self::MainColumn::State2MidLowLkOut,
                Self::AuxColumn::CascadeState2MidLowClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State2LowestLkIn,
                Self::MainColumn::State2LowestLkOut,
                Self::AuxColumn::CascadeState2LowestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State3HighestLkIn,
                Self::MainColumn::State3HighestLkOut,
                Self::AuxColumn::CascadeState3HighestClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State3MidHighLkIn,
                Self::MainColumn::State3MidHighLkOut,
                Self::AuxColumn::CascadeState3MidHighClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State3MidLowLkIn,
                Self::MainColumn::State3MidLowLkOut,
                Self::AuxColumn::CascadeState3MidLowClientLogDerivative,
            ),
            cascade_log_derivative_init_circuit(
                Self::MainColumn::State3LowestLkIn,
                Self::MainColumn::State3LowestLkOut,
                Self::AuxColumn::CascadeState3LowestClientLogDerivative,
            ),
        ]
    }
666
    /// Constraints that every single row of the Hash Table must satisfy.
    ///
    /// They require that:
    /// - the mode is one of the `HashTableMode`s;
    /// - the current instruction is `hash` unless the mode is `Sponge`;
    /// - the current instruction is `sponge_init`, `sponge_absorb`, or `sponge_squeeze`
    ///   if the mode is `Sponge`;
    /// - the round number is 0 in mode `Pad` and for instruction `sponge_init`;
    /// - the limbs of states 0 through 3 and the columns `State0Inv` through `State3Inv`
    ///   are consistent with each other (see the comments below);
    /// - state elements 10 through 15 are 0 for instruction `sponge_init`, and are 1 in
    ///   mode `Hash` with round number 0;
    /// - the round-constant columns hold Tip5's round constants for the row's round
    ///   number. No value is enforced for round number `NUM_ROUNDS`, because all
    ///   deselectors for rounds `0..NUM_ROUNDS` vanish there.
    fn consistency_constraints(
        circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
        let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
        let constant = |c: u64| circuit_builder.b_constant(c);
        let main_row = |column_id: Self::MainColumn| {
            circuit_builder.input(Main(column_id.master_main_index()))
        };

        let mode = main_row(Self::MainColumn::Mode);
        let ci = main_row(Self::MainColumn::CI);
        let round_number = main_row(Self::MainColumn::RoundNumber);

        // Each of these vanishes if and only if `ci` is the respective opcode.
        let ci_is_hash = ci.clone() - opcode(Instruction::Hash);
        let ci_is_sponge_init = ci.clone() - opcode(Instruction::SpongeInit);
        let ci_is_sponge_absorb = ci.clone() - opcode(Instruction::SpongeAbsorb);
        let ci_is_sponge_squeeze = ci - opcode(Instruction::SpongeSqueeze);

        // Deselectors: the first vanishes for every mode other than `Hash`, the second for
        // every round number in `1..=NUM_ROUNDS`.
        let mode_is_not_hash = Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash);
        let round_number_is_not_0 =
            Self::round_number_deselector(circuit_builder, &round_number, 0);

        // The product of `(mode - m)` over all modes `m`.
        let mode_is_a_valid_mode =
            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
                * Self::select_mode(circuit_builder, &mode, HashTableMode::Pad);

        let if_mode_is_not_sponge_then_ci_is_hash =
            Self::select_mode(circuit_builder, &mode, HashTableMode::Sponge) * ci_is_hash.clone();

        let if_mode_is_sponge_then_ci_is_a_sponge_instruction =
            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge)
                * ci_is_sponge_init
                * ci_is_sponge_absorb.clone()
                * ci_is_sponge_squeeze.clone();

        let if_padding_mode_then_round_number_is_0 =
            Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
                * round_number.clone();

        // Vanishes if `ci` is `hash`, `sponge_absorb`, or `sponge_squeeze`. Constraints
        // multiplied with it only bind in rows with instruction `sponge_init`.
        let if_ci_is_sponge_init_then_ = ci_is_hash * ci_is_sponge_absorb * ci_is_sponge_squeeze;
        let if_ci_is_sponge_init_then_round_number_is_0 =
            if_ci_is_sponge_init_then_.clone() * round_number.clone();

        // NOTE(review): despite the name, this constrains state elements 10 through 15.
        // If `tip5::RATE` is 10 (as the 10-element rate array in `initial_constraints`
        // suggests), these are the capacity elements, not the rate.
        let if_ci_is_sponge_init_then_rate_is_0 = (10..=15).map(|state_index| {
            let state_element = main_row(Self::state_column_by_index(state_index));
            if_ci_is_sponge_init_then_.clone() * state_element
        });

        // Non-zero only in rows with mode `Hash` and round number 0.
        let if_mode_is_hash_and_round_no_is_0_then_ = round_number_is_not_0 * mode_is_not_hash;
        let if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1 =
            (10..=15).map(|state_index| {
                let state_element = main_row(Self::state_column_by_index(state_index));
                if_mode_is_hash_and_round_no_is_0_then_.clone() * (state_element - constant(1))
            });

        // Limb consistency for states 0 through 3. For each state, let
        // `x = (2^32 - 1) - (highest * 2^16 + mid_high)`. Despite the variable names,
        // `x` is zero if and only if the two high limbs are all 1s.
        let one = constant(1);
        let two_pow_16 = constant(1 << 16);
        let two_pow_32 = constant(1 << 32);
        let state_0_hi_limbs_minus_2_pow_32 = two_pow_32.clone()
            - one.clone()
            - main_row(Self::MainColumn::State0HighestLkIn) * two_pow_16.clone()
            - main_row(Self::MainColumn::State0MidHighLkIn);
        let state_1_hi_limbs_minus_2_pow_32 = two_pow_32.clone()
            - one.clone()
            - main_row(Self::MainColumn::State1HighestLkIn) * two_pow_16.clone()
            - main_row(Self::MainColumn::State1MidHighLkIn);
        let state_2_hi_limbs_minus_2_pow_32 = two_pow_32.clone()
            - one.clone()
            - main_row(Self::MainColumn::State2HighestLkIn) * two_pow_16.clone()
            - main_row(Self::MainColumn::State2MidHighLkIn);
        let state_3_hi_limbs_minus_2_pow_32 = two_pow_32
            - one.clone()
            - main_row(Self::MainColumn::State3HighestLkIn) * two_pow_16.clone()
            - main_row(Self::MainColumn::State3MidHighLkIn);

        let state_0_hi_limbs_inv = main_row(Self::MainColumn::State0Inv);
        let state_1_hi_limbs_inv = main_row(Self::MainColumn::State1Inv);
        let state_2_hi_limbs_inv = main_row(Self::MainColumn::State2Inv);
        let state_3_hi_limbs_inv = main_row(Self::MainColumn::State3Inv);

        // `x * inv - 1`: zero if `inv` is the inverse of `x`, and `-1` if `x` is zero.
        let state_0_hi_limbs_are_not_all_1s =
            state_0_hi_limbs_minus_2_pow_32.clone() * state_0_hi_limbs_inv.clone() - one.clone();
        let state_1_hi_limbs_are_not_all_1s =
            state_1_hi_limbs_minus_2_pow_32.clone() * state_1_hi_limbs_inv.clone() - one.clone();
        let state_2_hi_limbs_are_not_all_1s =
            state_2_hi_limbs_minus_2_pow_32.clone() * state_2_hi_limbs_inv.clone() - one.clone();
        let state_3_hi_limbs_are_not_all_1s =
            state_3_hi_limbs_minus_2_pow_32.clone() * state_3_hi_limbs_inv.clone() - one;

        // Together with the next group, these force `inv = x^(-1)` if `x != 0`, and
        // `inv = 0` if `x = 0`.
        let state_0_hi_limbs_inv_is_inv_or_is_zero =
            state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_inv;
        let state_1_hi_limbs_inv_is_inv_or_is_zero =
            state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_inv;
        let state_2_hi_limbs_inv_is_inv_or_is_zero =
            state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_inv;
        let state_3_hi_limbs_inv_is_inv_or_is_zero =
            state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_inv;

        let state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero =
            state_0_hi_limbs_are_not_all_1s.clone() * state_0_hi_limbs_minus_2_pow_32;
        let state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero =
            state_1_hi_limbs_are_not_all_1s.clone() * state_1_hi_limbs_minus_2_pow_32;
        let state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero =
            state_2_hi_limbs_are_not_all_1s.clone() * state_2_hi_limbs_minus_2_pow_32;
        let state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero =
            state_3_hi_limbs_are_not_all_1s.clone() * state_3_hi_limbs_minus_2_pow_32;

        // The value of the two low limbs: `mid_low * 2^16 + lowest`.
        let state_0_lo_limbs = main_row(Self::MainColumn::State0MidLowLkIn) * two_pow_16.clone()
            + main_row(Self::MainColumn::State0LowestLkIn);
        let state_1_lo_limbs = main_row(Self::MainColumn::State1MidLowLkIn) * two_pow_16.clone()
            + main_row(Self::MainColumn::State1LowestLkIn);
        let state_2_lo_limbs = main_row(Self::MainColumn::State2MidLowLkIn) * two_pow_16.clone()
            + main_row(Self::MainColumn::State2LowestLkIn);
        let state_3_lo_limbs = main_row(Self::MainColumn::State3MidLowLkIn) * two_pow_16
            + main_row(Self::MainColumn::State3LowestLkIn);

        // If the high limbs are all 1s, the low limbs must all be 0. This presumably rules
        // out limb decompositions of values not smaller than the field modulus.
        // NOTE(review): confirm.
        let if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0 =
            state_0_hi_limbs_are_not_all_1s * state_0_lo_limbs;
        let if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0 =
            state_1_hi_limbs_are_not_all_1s * state_1_lo_limbs;
        let if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0 =
            state_2_hi_limbs_are_not_all_1s * state_2_lo_limbs;
        let if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0 =
            state_3_hi_limbs_are_not_all_1s * state_3_lo_limbs;

        let mut constraints = vec![
            mode_is_a_valid_mode,
            if_mode_is_not_sponge_then_ci_is_hash,
            if_mode_is_sponge_then_ci_is_a_sponge_instruction,
            if_padding_mode_then_round_number_is_0,
            if_ci_is_sponge_init_then_round_number_is_0,
            state_0_hi_limbs_inv_is_inv_or_is_zero,
            state_1_hi_limbs_inv_is_inv_or_is_zero,
            state_2_hi_limbs_inv_is_inv_or_is_zero,
            state_3_hi_limbs_inv_is_inv_or_is_zero,
            state_0_hi_limbs_inv_is_inv_or_state_0_hi_limbs_is_zero,
            state_1_hi_limbs_inv_is_inv_or_state_1_hi_limbs_is_zero,
            state_2_hi_limbs_inv_is_inv_or_state_2_hi_limbs_is_zero,
            state_3_hi_limbs_inv_is_inv_or_state_3_hi_limbs_is_zero,
            if_state_0_hi_limbs_are_all_1_then_state_0_lo_limbs_are_all_0,
            if_state_1_hi_limbs_are_all_1_then_state_1_lo_limbs_are_all_0,
            if_state_2_hi_limbs_are_all_1_then_state_2_lo_limbs_are_all_0,
            if_state_3_hi_limbs_are_all_1_then_state_3_lo_limbs_are_all_0,
        ];

        constraints.extend(if_ci_is_sponge_init_then_rate_is_0);
        constraints.extend(if_mode_is_hash_and_round_no_is_0_then_states_10_through_15_are_1);

        // One constraint per round-constant column:
        // `sum_r deselector(round_number, r) * (column - round_constant(r))` over all
        // rounds `r` in `0..NUM_ROUNDS`. In a row with round number `k < NUM_ROUNDS`, only
        // the summand for `r = k` is non-zero, so the column must equal round `k`'s
        // constant.
        for round_constant_column_idx in 0..NUM_ROUND_CONSTANTS {
            let round_constant_column =
                Self::round_constant_column_by_index(round_constant_column_idx);
            let round_constant_column_circuit = main_row(round_constant_column);
            let mut round_constant_constraint_circuit = constant(0);
            for round_idx in 0..NUM_ROUNDS {
                let round_constants = Self::tip5_round_constants_by_round_number(round_idx);
                let round_constant = round_constants[round_constant_column_idx];
                let round_constant = circuit_builder.b_constant(round_constant);
                let round_deselector_circuit =
                    Self::round_number_deselector(circuit_builder, &round_number, round_idx);
                round_constant_constraint_circuit = round_constant_constraint_circuit
                    + round_deselector_circuit
                        * (round_constant_column_circuit.clone() - round_constant);
            }
            constraints.push(round_constant_constraint_circuit);
        }

        constraints
    }
837
838 fn transition_constraints(
839 circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
840 ) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
841 let challenge = |c| circuit_builder.challenge(c);
842 let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
843 let constant = |c: u64| circuit_builder.b_constant(c);
844
845 let opcode_hash = opcode(Instruction::Hash);
846 let opcode_sponge_init = opcode(Instruction::SpongeInit);
847 let opcode_sponge_absorb = opcode(Instruction::SpongeAbsorb);
848 let opcode_sponge_squeeze = opcode(Instruction::SpongeSqueeze);
849
850 let current_main_row = |column_idx: Self::MainColumn| {
851 circuit_builder.input(CurrentMain(column_idx.master_main_index()))
852 };
853 let next_main_row = |column_idx: Self::MainColumn| {
854 circuit_builder.input(NextMain(column_idx.master_main_index()))
855 };
856 let current_aux_row = |column_idx: Self::AuxColumn| {
857 circuit_builder.input(CurrentAux(column_idx.master_aux_index()))
858 };
859 let next_aux_row = |column_idx: Self::AuxColumn| {
860 circuit_builder.input(NextAux(column_idx.master_aux_index()))
861 };
862
863 let running_evaluation_initial = circuit_builder.x_constant(EvalArg::default_initial());
864
865 let prepare_chunk_indeterminate =
866 challenge(ChallengeId::ProgramAttestationPrepareChunkIndeterminate);
867 let receive_chunk_indeterminate =
868 challenge(ChallengeId::ProgramAttestationSendChunkIndeterminate);
869 let compress_program_digest_indeterminate =
870 challenge(ChallengeId::CompressProgramDigestIndeterminate);
871 let expected_program_digest = challenge(ChallengeId::CompressedProgramDigest);
872 let hash_input_eval_indeterminate = challenge(ChallengeId::HashInputIndeterminate);
873 let hash_digest_eval_indeterminate = challenge(ChallengeId::HashDigestIndeterminate);
874 let sponge_indeterminate = challenge(ChallengeId::SpongeIndeterminate);
875
876 let mode = current_main_row(Self::MainColumn::Mode);
877 let ci = current_main_row(Self::MainColumn::CI);
878 let round_number = current_main_row(Self::MainColumn::RoundNumber);
879 let running_evaluation_receive_chunk =
880 current_aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation);
881 let running_evaluation_hash_input =
882 current_aux_row(Self::AuxColumn::HashInputRunningEvaluation);
883 let running_evaluation_hash_digest =
884 current_aux_row(Self::AuxColumn::HashDigestRunningEvaluation);
885 let running_evaluation_sponge = current_aux_row(Self::AuxColumn::SpongeRunningEvaluation);
886
887 let mode_next = next_main_row(Self::MainColumn::Mode);
888 let ci_next = next_main_row(Self::MainColumn::CI);
889 let round_number_next = next_main_row(Self::MainColumn::RoundNumber);
890 let running_evaluation_receive_chunk_next =
891 next_aux_row(Self::AuxColumn::ReceiveChunkRunningEvaluation);
892 let running_evaluation_hash_input_next =
893 next_aux_row(Self::AuxColumn::HashInputRunningEvaluation);
894 let running_evaluation_hash_digest_next =
895 next_aux_row(Self::AuxColumn::HashDigestRunningEvaluation);
896 let running_evaluation_sponge_next = next_aux_row(Self::AuxColumn::SpongeRunningEvaluation);
897
898 let [state_0, state_1, state_2, state_3] =
899 Self::re_compose_states_0_through_3_before_lookup(
900 circuit_builder,
901 Self::indicate_column_index_in_current_main_row,
902 );
903
904 let state_current = [
905 state_0,
906 state_1,
907 state_2,
908 state_3,
909 current_main_row(Self::MainColumn::State4),
910 current_main_row(Self::MainColumn::State5),
911 current_main_row(Self::MainColumn::State6),
912 current_main_row(Self::MainColumn::State7),
913 current_main_row(Self::MainColumn::State8),
914 current_main_row(Self::MainColumn::State9),
915 current_main_row(Self::MainColumn::State10),
916 current_main_row(Self::MainColumn::State11),
917 current_main_row(Self::MainColumn::State12),
918 current_main_row(Self::MainColumn::State13),
919 current_main_row(Self::MainColumn::State14),
920 current_main_row(Self::MainColumn::State15),
921 ];
922
923 let (state_next, hash_function_round_correctly_performs_update) =
924 Self::tip5_constraints_as_circuits(circuit_builder);
925
926 let state_weights = [
927 ChallengeId::StackWeight0,
928 ChallengeId::StackWeight1,
929 ChallengeId::StackWeight2,
930 ChallengeId::StackWeight3,
931 ChallengeId::StackWeight4,
932 ChallengeId::StackWeight5,
933 ChallengeId::StackWeight6,
934 ChallengeId::StackWeight7,
935 ChallengeId::StackWeight8,
936 ChallengeId::StackWeight9,
937 ChallengeId::StackWeight10,
938 ChallengeId::StackWeight11,
939 ChallengeId::StackWeight12,
940 ChallengeId::StackWeight13,
941 ChallengeId::StackWeight14,
942 ChallengeId::StackWeight15,
943 ]
944 .map(challenge);
945
946 let round_number_is_not_num_rounds =
947 Self::round_number_deselector(circuit_builder, &round_number, NUM_ROUNDS);
948
949 let round_number_is_0_through_4_or_round_number_next_is_0 =
950 round_number_is_not_num_rounds * round_number_next.clone();
951
952 let next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one =
953 Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
954 * (ci.clone() - opcode_sponge_init.clone())
955 * (round_number.clone() - constant(NUM_ROUNDS as u64))
956 * (round_number_next.clone() - round_number.clone() - constant(1));
957
958 let compressed_digest = state_current[..Digest::LEN].iter().fold(
961 running_evaluation_initial.clone(),
962 |acc, digest_element| {
963 acc * compress_program_digest_indeterminate.clone() + digest_element.clone()
964 },
965 );
966 let if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest =
967 Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
968 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
969 * (compressed_digest - expected_program_digest);
970
971 let if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init =
972 Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
973 * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Sponge)
974 * (ci_next.clone() - opcode_sponge_init.clone());
975
976 let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change =
977 (round_number.clone() - constant(NUM_ROUNDS as u64))
978 * (ci.clone() - opcode_sponge_init.clone())
979 * (ci_next.clone() - ci.clone());
980
981 let if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change =
982 (round_number - constant(NUM_ROUNDS as u64))
983 * (ci.clone() - opcode_sponge_init.clone())
984 * (mode_next.clone() - mode.clone());
985
986 let if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad =
987 Self::mode_deselector(circuit_builder, &mode, HashTableMode::Sponge)
988 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Sponge)
989 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
990 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);
991
992 let if_mode_is_hash_then_mode_next_is_hash_or_pad =
993 Self::mode_deselector(circuit_builder, &mode, HashTableMode::Hash)
994 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
995 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);
996
997 let if_mode_is_pad_then_mode_next_is_pad =
998 Self::mode_deselector(circuit_builder, &mode, HashTableMode::Pad)
999 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad);
1000
1001 let difference_of_capacity_registers = state_current[tip5::RATE..]
1002 .iter()
1003 .zip_eq(state_next[tip5::RATE..].iter())
1004 .map(|(current, next)| next.clone() - current.clone())
1005 .collect_vec();
1006 let randomized_sum_of_capacity_differences = state_weights[tip5::RATE..]
1007 .iter()
1008 .zip_eq(difference_of_capacity_registers)
1009 .map(|(weight, state_difference)| weight.clone() * state_difference)
1010 .sum::<ConstraintCircuitMonad<_>>();
1011
1012 let capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing =
1013 Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1014 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
1015 * Self::select_mode(circuit_builder, &mode_next, HashTableMode::Pad)
1016 * (ci_next.clone() - opcode_sponge_init.clone())
1017 * randomized_sum_of_capacity_differences.clone();
1018
1019 let difference_of_state_registers = state_current
1020 .iter()
1021 .zip_eq(state_next.iter())
1022 .map(|(current, next)| next.clone() - current.clone())
1023 .collect_vec();
1024 let randomized_sum_of_state_differences = state_weights
1025 .iter()
1026 .zip_eq(difference_of_state_registers.iter())
1027 .map(|(weight, state_difference)| weight.clone() * state_difference.clone())
1028 .sum();
1029 let if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change =
1030 Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1031 * Self::instruction_deselector(
1032 circuit_builder,
1033 &ci_next,
1034 Instruction::SpongeSqueeze,
1035 )
1036 * randomized_sum_of_state_differences;
1037
1038 let running_evaluation_hash_input_remains =
1043 running_evaluation_hash_input_next.clone() - running_evaluation_hash_input.clone();
1044 let tip5_input = state_next[..tip5::RATE].to_owned();
1045 let compressed_row_from_processor = tip5_input
1046 .into_iter()
1047 .zip_eq(state_weights[..tip5::RATE].iter())
1048 .map(|(state, weight)| weight.clone() * state)
1049 .sum();
1050
1051 let running_evaluation_hash_input_updates = running_evaluation_hash_input_next
1052 - hash_input_eval_indeterminate * running_evaluation_hash_input
1053 - compressed_row_from_processor;
1054 let running_evaluation_hash_input_is_updated_correctly =
1055 Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1056 * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash)
1057 * running_evaluation_hash_input_updates
1058 + round_number_next.clone() * running_evaluation_hash_input_remains.clone()
1059 + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
1060 * running_evaluation_hash_input_remains;
1061
1062 let round_number_next_is_num_rounds =
1066 round_number_next.clone() - constant(NUM_ROUNDS as u64);
1067 let running_evaluation_hash_digest_remains =
1068 running_evaluation_hash_digest_next.clone() - running_evaluation_hash_digest.clone();
1069 let hash_digest = state_next[..Digest::LEN].to_owned();
1070 let compressed_row_hash_digest = hash_digest
1071 .into_iter()
1072 .zip_eq(state_weights[..Digest::LEN].iter())
1073 .map(|(state, weight)| weight.clone() * state)
1074 .sum();
1075 let running_evaluation_hash_digest_updates = running_evaluation_hash_digest_next
1076 - hash_digest_eval_indeterminate * running_evaluation_hash_digest
1077 - compressed_row_hash_digest;
1078 let running_evaluation_hash_digest_is_updated_correctly =
1079 Self::round_number_deselector(circuit_builder, &round_number_next, NUM_ROUNDS)
1080 * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::Hash)
1081 * running_evaluation_hash_digest_updates
1082 + round_number_next_is_num_rounds * running_evaluation_hash_digest_remains.clone()
1083 + Self::select_mode(circuit_builder, &mode_next, HashTableMode::Hash)
1084 * running_evaluation_hash_digest_remains;
1085
1086 let compressed_row_next = state_weights[..tip5::RATE]
1088 .iter()
1089 .zip_eq(state_next[..tip5::RATE].iter())
1090 .map(|(weight, st_next)| weight.clone() * st_next.clone())
1091 .sum();
1092 let running_evaluation_sponge_has_accumulated_ci = running_evaluation_sponge_next.clone()
1093 - sponge_indeterminate * running_evaluation_sponge.clone()
1094 - challenge(ChallengeId::HashCIWeight) * ci_next.clone();
1095 let running_evaluation_sponge_has_accumulated_next_row =
1096 running_evaluation_sponge_has_accumulated_ci.clone() - compressed_row_next;
1097 let if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates =
1098 Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1099 * (ci_next.clone() - opcode_hash)
1100 * running_evaluation_sponge_has_accumulated_next_row;
1101
1102 let running_evaluation_sponge_remains =
1103 running_evaluation_sponge_next - running_evaluation_sponge;
1104 let if_round_no_next_is_not_0_then_running_evaluation_sponge_remains =
1105 round_number_next.clone() * running_evaluation_sponge_remains.clone();
1106 let if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains = (ci_next.clone()
1107 - opcode_sponge_init)
1108 * (ci_next.clone() - opcode_sponge_absorb)
1109 * (ci_next - opcode_sponge_squeeze)
1110 * running_evaluation_sponge_remains;
1111 let running_evaluation_sponge_is_updated_correctly =
1112 if_round_no_next_0_and_ci_next_is_spongy_then_running_evaluation_sponge_updates
1113 + if_round_no_next_is_not_0_then_running_evaluation_sponge_remains
1114 + if_ci_next_is_not_spongy_then_running_evaluation_sponge_remains;
1115
1116 let compressed_chunk = state_next[..tip5::RATE]
1119 .iter()
1120 .fold(running_evaluation_initial, |acc, rate_element| {
1121 acc * prepare_chunk_indeterminate.clone() + rate_element.clone()
1122 });
1123 let receive_chunk_running_evaluation_absorbs_chunk_of_instructions =
1124 running_evaluation_receive_chunk_next.clone()
1125 - receive_chunk_indeterminate * running_evaluation_receive_chunk.clone()
1126 - compressed_chunk;
1127 let receive_chunk_running_evaluation_remains =
1128 running_evaluation_receive_chunk_next - running_evaluation_receive_chunk;
1129 let receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0 =
1130 Self::round_number_deselector(circuit_builder, &round_number_next, 0)
1131 * Self::mode_deselector(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
1132 * receive_chunk_running_evaluation_absorbs_chunk_of_instructions
1133 + round_number_next * receive_chunk_running_evaluation_remains.clone()
1134 + Self::select_mode(circuit_builder, &mode_next, HashTableMode::ProgramHashing)
1135 * receive_chunk_running_evaluation_remains;
1136
1137 let constraints = vec![
1138 round_number_is_0_through_4_or_round_number_next_is_0,
1139 next_mode_is_padding_mode_or_round_number_is_num_rounds_or_increments_by_one,
1140 receive_chunk_of_instructions_iff_next_mode_is_prog_hashing_and_next_round_number_is_0,
1141 if_mode_changes_from_program_hashing_then_current_digest_is_expected_program_digest,
1142 if_mode_is_program_hashing_and_next_mode_is_sponge_then_ci_next_is_sponge_init,
1143 if_round_number_is_not_max_and_ci_is_not_sponge_init_then_ci_doesnt_change,
1144 if_round_number_is_not_max_and_ci_is_not_sponge_init_then_mode_doesnt_change,
1145 if_mode_is_sponge_then_mode_next_is_sponge_or_hash_or_pad,
1146 if_mode_is_hash_then_mode_next_is_hash_or_pad,
1147 if_mode_is_pad_then_mode_next_is_pad,
1148 capacity_doesnt_change_at_section_start_when_program_hashing_or_absorbing,
1149 if_round_number_next_is_0_and_ci_next_is_squeeze_then_state_doesnt_change,
1150 running_evaluation_hash_input_is_updated_correctly,
1151 running_evaluation_hash_digest_is_updated_correctly,
1152 running_evaluation_sponge_is_updated_correctly,
1153 Self::cascade_log_derivative_update_circuit(
1154 circuit_builder,
1155 Self::MainColumn::State0HighestLkIn,
1156 Self::MainColumn::State0HighestLkOut,
1157 Self::AuxColumn::CascadeState0HighestClientLogDerivative,
1158 ),
1159 Self::cascade_log_derivative_update_circuit(
1160 circuit_builder,
1161 Self::MainColumn::State0MidHighLkIn,
1162 Self::MainColumn::State0MidHighLkOut,
1163 Self::AuxColumn::CascadeState0MidHighClientLogDerivative,
1164 ),
1165 Self::cascade_log_derivative_update_circuit(
1166 circuit_builder,
1167 Self::MainColumn::State0MidLowLkIn,
1168 Self::MainColumn::State0MidLowLkOut,
1169 Self::AuxColumn::CascadeState0MidLowClientLogDerivative,
1170 ),
1171 Self::cascade_log_derivative_update_circuit(
1172 circuit_builder,
1173 Self::MainColumn::State0LowestLkIn,
1174 Self::MainColumn::State0LowestLkOut,
1175 Self::AuxColumn::CascadeState0LowestClientLogDerivative,
1176 ),
1177 Self::cascade_log_derivative_update_circuit(
1178 circuit_builder,
1179 Self::MainColumn::State1HighestLkIn,
1180 Self::MainColumn::State1HighestLkOut,
1181 Self::AuxColumn::CascadeState1HighestClientLogDerivative,
1182 ),
1183 Self::cascade_log_derivative_update_circuit(
1184 circuit_builder,
1185 Self::MainColumn::State1MidHighLkIn,
1186 Self::MainColumn::State1MidHighLkOut,
1187 Self::AuxColumn::CascadeState1MidHighClientLogDerivative,
1188 ),
1189 Self::cascade_log_derivative_update_circuit(
1190 circuit_builder,
1191 Self::MainColumn::State1MidLowLkIn,
1192 Self::MainColumn::State1MidLowLkOut,
1193 Self::AuxColumn::CascadeState1MidLowClientLogDerivative,
1194 ),
1195 Self::cascade_log_derivative_update_circuit(
1196 circuit_builder,
1197 Self::MainColumn::State1LowestLkIn,
1198 Self::MainColumn::State1LowestLkOut,
1199 Self::AuxColumn::CascadeState1LowestClientLogDerivative,
1200 ),
1201 Self::cascade_log_derivative_update_circuit(
1202 circuit_builder,
1203 Self::MainColumn::State2HighestLkIn,
1204 Self::MainColumn::State2HighestLkOut,
1205 Self::AuxColumn::CascadeState2HighestClientLogDerivative,
1206 ),
1207 Self::cascade_log_derivative_update_circuit(
1208 circuit_builder,
1209 Self::MainColumn::State2MidHighLkIn,
1210 Self::MainColumn::State2MidHighLkOut,
1211 Self::AuxColumn::CascadeState2MidHighClientLogDerivative,
1212 ),
1213 Self::cascade_log_derivative_update_circuit(
1214 circuit_builder,
1215 Self::MainColumn::State2MidLowLkIn,
1216 Self::MainColumn::State2MidLowLkOut,
1217 Self::AuxColumn::CascadeState2MidLowClientLogDerivative,
1218 ),
1219 Self::cascade_log_derivative_update_circuit(
1220 circuit_builder,
1221 Self::MainColumn::State2LowestLkIn,
1222 Self::MainColumn::State2LowestLkOut,
1223 Self::AuxColumn::CascadeState2LowestClientLogDerivative,
1224 ),
1225 Self::cascade_log_derivative_update_circuit(
1226 circuit_builder,
1227 Self::MainColumn::State3HighestLkIn,
1228 Self::MainColumn::State3HighestLkOut,
1229 Self::AuxColumn::CascadeState3HighestClientLogDerivative,
1230 ),
1231 Self::cascade_log_derivative_update_circuit(
1232 circuit_builder,
1233 Self::MainColumn::State3MidHighLkIn,
1234 Self::MainColumn::State3MidHighLkOut,
1235 Self::AuxColumn::CascadeState3MidHighClientLogDerivative,
1236 ),
1237 Self::cascade_log_derivative_update_circuit(
1238 circuit_builder,
1239 Self::MainColumn::State3MidLowLkIn,
1240 Self::MainColumn::State3MidLowLkOut,
1241 Self::AuxColumn::CascadeState3MidLowClientLogDerivative,
1242 ),
1243 Self::cascade_log_derivative_update_circuit(
1244 circuit_builder,
1245 Self::MainColumn::State3LowestLkIn,
1246 Self::MainColumn::State3LowestLkOut,
1247 Self::AuxColumn::CascadeState3LowestClientLogDerivative,
1248 ),
1249 ];
1250
1251 [
1252 constraints,
1253 hash_function_round_correctly_performs_update.to_vec(),
1254 ]
1255 .concat()
1256 }
1257
1258 fn terminal_constraints(
1259 circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
1260 ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
1261 let challenge = |c| circuit_builder.challenge(c);
1262 let opcode = |instruction: Instruction| circuit_builder.b_constant(instruction.opcode_b());
1263 let constant = |c: u64| circuit_builder.b_constant(c);
1264 let main_row = |column_idx: Self::MainColumn| {
1265 circuit_builder.input(Main(column_idx.master_main_index()))
1266 };
1267
1268 let mode = main_row(Self::MainColumn::Mode);
1269 let round_number = main_row(Self::MainColumn::RoundNumber);
1270
1271 let compress_program_digest_indeterminate =
1272 challenge(ChallengeId::CompressProgramDigestIndeterminate);
1273 let expected_program_digest = challenge(ChallengeId::CompressedProgramDigest);
1274
1275 let max_round_number = constant(NUM_ROUNDS as u64);
1276
1277 let [state_0, state_1, state_2, state_3] =
1278 Self::re_compose_states_0_through_3_before_lookup(
1279 circuit_builder,
1280 Self::indicate_column_index_in_main_row,
1281 );
1282 let state_4 = main_row(Self::MainColumn::State4);
1283 let program_digest = [state_0, state_1, state_2, state_3, state_4];
1284 let compressed_digest = program_digest.into_iter().fold(
1285 circuit_builder.x_constant(EvalArg::default_initial()),
1286 |acc, digest_element| {
1287 acc * compress_program_digest_indeterminate.clone() + digest_element
1288 },
1289 );
1290 let if_mode_is_program_hashing_then_current_digest_is_expected_program_digest =
1291 Self::mode_deselector(circuit_builder, &mode, HashTableMode::ProgramHashing)
1292 * (compressed_digest - expected_program_digest);
1293
1294 let if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number =
1295 Self::select_mode(circuit_builder, &mode, HashTableMode::Pad)
1296 * (main_row(Self::MainColumn::CI) - opcode(Instruction::SpongeInit))
1297 * (round_number - max_round_number);
1298
1299 vec![
1300 if_mode_is_program_hashing_then_current_digest_is_expected_program_digest,
1301 if_mode_is_not_pad_and_ci_is_not_sponge_init_then_round_number_is_max_round_number,
1302 ]
1303 }
1304}
1305
/// The mode a row of the Hash Table is in.
///
/// The numeric encoding of each mode is defined by the `From<HashTableMode>` conversions into
/// `u32`, `u64`, and `BFieldElement`. Note that `Pad` is encoded as `0`, not as the last value.
#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Hash, EnumCount, EnumIter)]
pub enum HashTableMode {
    /// Encoded as `1`. According to the transition constraints, chunks of instructions are
    /// received in rows of this mode. When the mode changes away from it (or the last row is in
    /// it), the compressed current digest must equal the expected program digest.
    ProgramHashing,

    /// Encoded as `2`. According to the consistency constraint names, the current instruction
    /// must then be a sponge instruction. The next mode is `Sponge`, `Hash`, or `Pad`.
    Sponge,

    /// Encoded as `3`. Rows of this mode drive the Hash Input and Hash Digest running
    /// evaluations. The next mode is `Hash` or `Pad`.
    Hash,

    /// Encoded as `0`. Padding rows: the round number is `0`, and the next mode is `Pad` again.
    Pad,
}
1359
1360impl From<HashTableMode> for u32 {
1361 fn from(mode: HashTableMode) -> Self {
1362 match mode {
1363 HashTableMode::ProgramHashing => 1,
1364 HashTableMode::Sponge => 2,
1365 HashTableMode::Hash => 3,
1366 HashTableMode::Pad => 0,
1367 }
1368 }
1369}
1370
1371impl From<HashTableMode> for u64 {
1372 fn from(mode: HashTableMode) -> Self {
1373 let discriminant: u32 = mode.into();
1374 discriminant.into()
1375 }
1376}
1377
1378impl From<HashTableMode> for BFieldElement {
1379 fn from(mode: HashTableMode) -> Self {
1380 let discriminant: u32 = mode.into();
1381 discriminant.into()
1382 }
1383}