triton_air/table/
jump_stack.rs

1use constraint_circuit::ConstraintCircuitBuilder;
2use constraint_circuit::ConstraintCircuitMonad;
3use constraint_circuit::DualRowIndicator;
4use constraint_circuit::DualRowIndicator::CurrentAux;
5use constraint_circuit::DualRowIndicator::CurrentMain;
6use constraint_circuit::DualRowIndicator::NextAux;
7use constraint_circuit::DualRowIndicator::NextMain;
8use constraint_circuit::SingleRowIndicator;
9use constraint_circuit::SingleRowIndicator::Aux;
10use constraint_circuit::SingleRowIndicator::Main;
11use isa::instruction::Instruction;
12use twenty_first::prelude::BFieldElement;
13
14use crate::AIR;
15use crate::challenge_id::ChallengeId::ClockJumpDifferenceLookupIndeterminate;
16use crate::challenge_id::ChallengeId::JumpStackCiWeight;
17use crate::challenge_id::ChallengeId::JumpStackClkWeight;
18use crate::challenge_id::ChallengeId::JumpStackIndeterminate;
19use crate::challenge_id::ChallengeId::JumpStackJsdWeight;
20use crate::challenge_id::ChallengeId::JumpStackJsoWeight;
21use crate::challenge_id::ChallengeId::JumpStackJspWeight;
22use crate::cross_table_argument::CrossTableArg;
23use crate::cross_table_argument::LookupArg;
24use crate::table_column::JumpStackAuxColumn::ClockJumpDifferenceLookupClientLogDerivative;
25use crate::table_column::JumpStackAuxColumn::RunningProductPermArg;
26use crate::table_column::JumpStackMainColumn::CI;
27use crate::table_column::JumpStackMainColumn::CLK;
28use crate::table_column::JumpStackMainColumn::JSD;
29use crate::table_column::JumpStackMainColumn::JSO;
30use crate::table_column::JumpStackMainColumn::JSP;
31use crate::table_column::MasterAuxColumn;
32use crate::table_column::MasterMainColumn;
33
34#[derive(Debug, Copy, Clone, Eq, PartialEq)]
35pub struct JumpStackTable;
36
37impl crate::private::Seal for JumpStackTable {}
38
39impl AIR for JumpStackTable {
40    type MainColumn = crate::table_column::JumpStackMainColumn;
41    type AuxColumn = crate::table_column::JumpStackAuxColumn;
42
43    fn initial_constraints(
44        circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
45    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
46        let clk = circuit_builder.input(Main(CLK.master_main_index()));
47        let jsp = circuit_builder.input(Main(JSP.master_main_index()));
48        let jso = circuit_builder.input(Main(JSO.master_main_index()));
49        let jsd = circuit_builder.input(Main(JSD.master_main_index()));
50        let ci = circuit_builder.input(Main(CI.master_main_index()));
51        let rppa = circuit_builder.input(Aux(RunningProductPermArg.master_aux_index()));
52        let clock_jump_diff_log_derivative = circuit_builder.input(Aux(
53            ClockJumpDifferenceLookupClientLogDerivative.master_aux_index(),
54        ));
55
56        let processor_perm_indeterminate = circuit_builder.challenge(JumpStackIndeterminate);
57        // note: `clk`, `jsp`, `jso`, and `jsd` are all constrained to be 0 and
58        // can thus be omitted.
59        let compressed_row = circuit_builder.challenge(JumpStackCiWeight) * ci;
60        let rppa_starts_correctly = rppa - (processor_perm_indeterminate - compressed_row);
61
62        // A clock jump difference of 0 is not allowed. Hence, the initial is
63        // recorded.
64        let clock_jump_diff_log_derivative_starts_correctly = clock_jump_diff_log_derivative
65            - circuit_builder.x_constant(LookupArg::default_initial());
66
67        vec![
68            clk,
69            jsp,
70            jso,
71            jsd,
72            rppa_starts_correctly,
73            clock_jump_diff_log_derivative_starts_correctly,
74        ]
75    }
76
77    fn consistency_constraints(
78        _circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
79    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
80        // no further constraints
81        vec![]
82    }
83
84    fn transition_constraints(
85        circuit_builder: &ConstraintCircuitBuilder<DualRowIndicator>,
86    ) -> Vec<ConstraintCircuitMonad<DualRowIndicator>> {
87        let one = || circuit_builder.b_constant(1);
88        let call_opcode =
89            circuit_builder.b_constant(Instruction::Call(BFieldElement::default()).opcode_b());
90        let return_opcode = circuit_builder.b_constant(Instruction::Return.opcode_b());
91        let recurse_or_return_opcode =
92            circuit_builder.b_constant(Instruction::RecurseOrReturn.opcode_b());
93
94        let clk = circuit_builder.input(CurrentMain(CLK.master_main_index()));
95        let ci = circuit_builder.input(CurrentMain(CI.master_main_index()));
96        let jsp = circuit_builder.input(CurrentMain(JSP.master_main_index()));
97        let jso = circuit_builder.input(CurrentMain(JSO.master_main_index()));
98        let jsd = circuit_builder.input(CurrentMain(JSD.master_main_index()));
99        let rppa = circuit_builder.input(CurrentAux(RunningProductPermArg.master_aux_index()));
100        let clock_jump_diff_log_derivative = circuit_builder.input(CurrentAux(
101            ClockJumpDifferenceLookupClientLogDerivative.master_aux_index(),
102        ));
103
104        let clk_next = circuit_builder.input(NextMain(CLK.master_main_index()));
105        let ci_next = circuit_builder.input(NextMain(CI.master_main_index()));
106        let jsp_next = circuit_builder.input(NextMain(JSP.master_main_index()));
107        let jso_next = circuit_builder.input(NextMain(JSO.master_main_index()));
108        let jsd_next = circuit_builder.input(NextMain(JSD.master_main_index()));
109        let rppa_next = circuit_builder.input(NextAux(RunningProductPermArg.master_aux_index()));
110        let clock_jump_diff_log_derivative_next = circuit_builder.input(NextAux(
111            ClockJumpDifferenceLookupClientLogDerivative.master_aux_index(),
112        ));
113
114        let jsp_inc_or_stays =
115            (jsp_next.clone() - jsp.clone() - one()) * (jsp_next.clone() - jsp.clone());
116
117        let jsp_inc_by_one_or_ci_can_return = (jsp_next.clone() - jsp.clone() - one())
118            * (ci.clone() - return_opcode)
119            * (ci.clone() - recurse_or_return_opcode);
120        let jsp_inc_or_jso_stays_or_ci_can_ret =
121            jsp_inc_by_one_or_ci_can_return.clone() * (jso_next.clone() - jso);
122
123        let jsp_inc_or_jsd_stays_or_ci_can_ret =
124            jsp_inc_by_one_or_ci_can_return.clone() * (jsd_next.clone() - jsd);
125
126        let jsp_inc_or_clk_inc_or_ci_call_or_ci_can_ret = jsp_inc_by_one_or_ci_can_return
127            * (clk_next.clone() - clk.clone() - one())
128            * (ci.clone() - call_opcode);
129
130        let compressed_row = circuit_builder.challenge(JumpStackClkWeight) * clk_next.clone()
131            + circuit_builder.challenge(JumpStackCiWeight) * ci_next
132            + circuit_builder.challenge(JumpStackJspWeight) * jsp_next.clone()
133            + circuit_builder.challenge(JumpStackJsoWeight) * jso_next
134            + circuit_builder.challenge(JumpStackJsdWeight) * jsd_next;
135        let rppa_updates_correctly =
136            rppa_next - rppa * (circuit_builder.challenge(JumpStackIndeterminate) - compressed_row);
137
138        let log_derivative_remains =
139            clock_jump_diff_log_derivative_next.clone() - clock_jump_diff_log_derivative.clone();
140        let clk_diff = clk_next - clk;
141        let log_derivative_accumulates = (clock_jump_diff_log_derivative_next
142            - clock_jump_diff_log_derivative)
143            * (circuit_builder.challenge(ClockJumpDifferenceLookupIndeterminate) - clk_diff)
144            - one();
145        let log_derivative_updates_correctly = (jsp_next.clone() - jsp.clone() - one())
146            * log_derivative_accumulates
147            + (jsp_next - jsp) * log_derivative_remains;
148
149        vec![
150            jsp_inc_or_stays,
151            jsp_inc_or_jso_stays_or_ci_can_ret,
152            jsp_inc_or_jsd_stays_or_ci_can_ret,
153            jsp_inc_or_clk_inc_or_ci_call_or_ci_can_ret,
154            rppa_updates_correctly,
155            log_derivative_updates_correctly,
156        ]
157    }
158
159    fn terminal_constraints(
160        _circuit_builder: &ConstraintCircuitBuilder<SingleRowIndicator>,
161    ) -> Vec<ConstraintCircuitMonad<SingleRowIndicator>> {
162        // no further constraints
163        vec![]
164    }
165}