1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use core::borrow::Borrow;
use p3_air::AirBuilder;
use p3_baby_bear::BabyBear;
use p3_field::AbstractField;
use p3_field::Field;
use p3_field::PrimeField32;
use p3_field::TwoAdicField;
use p3_matrix::Matrix;

use crate::air::BaseAirBuilder;
use crate::air::SP1AirBuilder;
use crate::operations::IsZeroOperation;

use super::ShaExtendChip;
use super::ShaExtendCols;

impl<F: Field> ShaExtendCols<F> {
    /// Fills in the row-position flag columns for trace row `i`.
    ///
    /// - `cycle_16` is set to `root^(i + 1)`, where `root` is BabyBear's two-adic generator of
    ///   order 16 (`two_adic_generator(4)`), lifted into `F` through its canonical `u32`
    ///   (NOTE(review): this only yields an order-16 element when `F` is BabyBear — confirm).
    /// - `cycle_16_start` / `cycle_16_end` are populated from `cycle_16 - root` and
    ///   `cycle_16 - 1`, i.e. they track whether the row opens or closes a 16-row segment.
    /// - `i` is set to `16 + (i % 48)`, and `cycle_48[k]` is 1 exactly when that value lies in
    ///   `16 * (k + 1) .. 16 * (k + 2)`, i.e. the row is in the `k`-th 16-row segment.
    /// - `cycle_48_start` / `cycle_48_end` are the products of those flags with `is_real`.
    ///
    /// `self.is_real` is only read here, so it must already hold its final value before this
    /// method is called.
    pub fn populate_flags(&mut self, i: usize) {
        // Element of multiplicative order 16, so its powers repeat every 16 rows.
        let root = F::from_canonical_u32(BabyBear::two_adic_generator(4).as_canonical_u32());
        let power = root.exp_u64((i + 1) as u64);
        self.cycle_16 = power;

        // Opening row of a 16-row segment: `power == root`.
        self.cycle_16_start.populate_from_field_element(power - root);
        // Closing row of a 16-row segment: `power == root^16 == 1`.
        self.cycle_16_end.populate_from_field_element(power - F::one());

        // Position within the 48-row cycle, and which of its three 16-row segments we're in.
        let offset = i % 48;
        let segment = offset / 16;
        self.i = F::from_canonical_usize(16 + offset);
        self.cycle_48[0] = F::from_bool(segment == 0);
        self.cycle_48[1] = F::from_bool(segment == 1);
        self.cycle_48[2] = F::from_bool(segment == 2);

        // Both boundary flags are gated by `is_real`, which the caller populated earlier.
        let real = self.is_real;
        self.cycle_48_start = self.cycle_48[0] * self.cycle_16_start.result * real;
        self.cycle_48_end = self.cycle_48[2] * self.cycle_16_end.result * real;
    }
}

impl ShaExtendChip {
    /// Constrains the row-position flag columns that `ShaExtendCols::populate_flags` fills in.
    ///
    /// Enforced relations:
    /// - `cycle_16` equals `g` on the first row and is multiplied by `g` on every transition,
    ///   where `g` is BabyBear's order-16 two-adic generator, so it repeats every 16 rows.
    /// - `cycle_16_start.result` / `cycle_16_end.result` are the is-zero flags of
    ///   `cycle_16 - g` and `cycle_16 - 1`, marking the first / last row of each 16-row segment.
    /// - `cycle_48` is `[1, 0, 0]` on the first row, boolean on every row, and rotates one slot
    ///   forward after the last row of each 16-row segment (otherwise it is unchanged).
    /// - `cycle_48_start` / `cycle_48_end` equal the corresponding segment-boundary products
    ///   gated by `is_real`.
    /// - `i` is 16 on the first row, increments by one per transition, and returns to 16 after
    ///   the last row of the third segment, so it runs through 16..=63 every 48 rows.
    ///
    /// NOTE(review): the order in which constraints are emitted, and the operand order inside
    /// each `assert_eq`, are part of the constraint system seen by the prover/verifier; keep
    /// them stable when editing — confirm against the builder's folding before reordering.
    pub fn eval_flags<AB: SP1AirBuilder>(&self, builder: &mut AB) {
        // Current (`local`) and following (`next`) rows of the main trace.
        let main = builder.main();
        let (local, next) = (main.row_slice(0), main.row_slice(1));
        let local: &ShaExtendCols<AB::Var> = (*local).borrow();
        let next: &ShaExtendCols<AB::Var> = (*next).borrow();

        // Constant 1 as an expression; cloned into several constraints below.
        let one = AB::Expr::from(AB::F::one());

        // Generator with order 16 within BabyBear.
        let g = AB::F::from_canonical_u32(BabyBear::two_adic_generator(4).as_canonical_u32());

        // First row of the table must have g^1.
        builder.when_first_row().assert_eq(local.cycle_16, g);

        // First row of the table must have i = 16.
        builder
            .when_first_row()
            .assert_eq(local.i, AB::F::from_canonical_u32(16));

        // On every transition, the next row's `cycle_16` is this row's value multiplied by `g`.
        builder
            .when_transition()
            .assert_eq(local.cycle_16 * g, next.cycle_16);

        // Constrain `cycle_16_start.result` to be `cycle_16 - g == 0` (first row of a segment).
        // The final argument is the constant 1; presumably the gadget's enable flag, so this is
        // enforced on every row — confirm against `IsZeroOperation::eval`.
        IsZeroOperation::<AB::F>::eval(
            builder,
            local.cycle_16 - AB::Expr::from(g),
            local.cycle_16_start,
            one.clone(),
        );

        // Constrain `cycle_16_end.result` to be `cycle_16 - 1 == 0`. Since `g` has order 16,
        // `g^16 == 1`, so this flags the last (16th) row of each 16-row segment.
        IsZeroOperation::<AB::F>::eval(
            builder,
            local.cycle_16 - AB::Expr::one(),
            local.cycle_16_end,
            one.clone(),
        );

        // Constrain `cycle_48` to be [1, 0, 0] in the first row.
        builder
            .when_first_row()
            .assert_eq(local.cycle_48[0], AB::F::one());
        builder
            .when_first_row()
            .assert_eq(local.cycle_48[1], AB::F::zero());
        builder
            .when_first_row()
            .assert_eq(local.cycle_48[2], AB::F::zero());

        // Rotate `cycle_48` one slot forward (next[(k + 1) % 3] = local[k]) after the last row of
        // each 16-row segment; otherwise keep it unchanged. Each entry must also be boolean.
        for i in 0..3 {
            builder
                .when_transition()
                .when(local.cycle_16_end.result)
                .assert_eq(local.cycle_48[i], next.cycle_48[(i + 1) % 3]);
            builder
                .when_transition()
                .when(one.clone() - local.cycle_16_end.result)
                .assert_eq(local.cycle_48[i], next.cycle_48[i]);
            builder.assert_bool(local.cycle_48[i]);
        }

        // cycle_48_start == start of 16-cycle AND first 16-cycle within 48-cycle AND is_real.
        builder.assert_eq(
            local.cycle_16_start.result * local.cycle_48[0] * local.is_real,
            local.cycle_48_start,
        );

        // cycle_48_end == end of 16-cycle AND last 16-cycle within 48-cycle AND is_real.
        builder.assert_eq(
            local.cycle_16_end.result * local.cycle_48[2] * local.is_real,
            local.cycle_48_end,
        );

        // When it's the end of a 48-cycle, the next `i` must be 16.
        builder
            .when_transition()
            .when(local.cycle_16_end.result * local.cycle_48[2])
            .assert_eq(next.i, AB::F::from_canonical_u32(16));

        // When it's not the end of a 48-cycle, the next `i` must be the current plus one.
        builder
            .when_transition()
            .when_not(local.cycle_16_end.result * local.cycle_48[2])
            .assert_eq(local.i + one.clone(), next.i);
    }
}