1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use core::borrow::Borrow;

use p3_air::PairBuilder;
use p3_air::{Air, BaseAir};
use p3_field::AbstractField;
use p3_field::Field;
use p3_matrix::Matrix;

use super::columns::{ByteMultCols, BytePreprocessedCols, NUM_BYTE_MULT_COLS};
use super::{ByteChip, ByteOpcode, NUM_BYTE_LOOKUP_CHANNELS};
use crate::air::SP1AirBuilder;

impl<F: Field> BaseAir<F> for ByteChip<F> {
    /// Width of the main trace of the byte chip.
    ///
    /// The main trace row is interpreted as `ByteMultCols` in `eval`, so its width is
    /// `NUM_BYTE_MULT_COLS`. The byte-table operand columns are read from the preprocessed
    /// trace and are not counted here.
    fn width(&self) -> usize {
        NUM_BYTE_MULT_COLS
    }
}

impl<AB: SP1AirBuilder + PairBuilder> Air<AB> for ByteChip<AB::F> {
    /// Receives every byte lookup served by the byte table.
    ///
    /// For each lookup channel in `0..NUM_BYTE_LOOKUP_CHANNELS` and each opcode yielded by
    /// `ByteOpcode::all()` (in that order), one receive interaction is issued:
    /// - the multiplicity is `mult_channels[channel].multiplicities[opcode_index]` from the
    ///   main trace row;
    /// - the operands/results come from the preprocessed row (`BytePreprocessedCols`);
    /// - every interaction is tagged with the main row's `shard` column and the channel index
    ///   lifted into the field.
    ///
    /// `ShrCarry` produces two results and therefore uses `receive_byte_pair`; all other
    /// opcodes use `receive_byte`, with zero filling the slots an opcode does not use.
    fn eval(&self, builder: &mut AB) {
        // Main trace row: the shard column plus per-channel, per-opcode multiplicities.
        let main_trace = builder.main();
        let mult_row = main_trace.row_slice(0);
        let mult_cols: &ByteMultCols<AB::Var> = (*mult_row).borrow();

        // Preprocessed row: the fixed operand and result columns of the byte table.
        let prep_trace = builder.preprocessed();
        let prep_row = prep_trace.row_slice(0);
        let table: &BytePreprocessedCols<AB::Var> = (*prep_row).borrow();

        // These values do not depend on the channel or opcode, so read them once.
        let shard = mult_cols.shard;
        let zero = AB::F::zero();
        let opcodes = ByteOpcode::all();

        for channel_idx in 0..NUM_BYTE_LOOKUP_CHANNELS {
            let channel = AB::F::from_canonical_u32(channel_idx);
            let multiplicities = &mult_cols.mult_channels[channel_idx as usize].multiplicities;

            for (op_idx, opcode) in opcodes.iter().enumerate() {
                let op = opcode.as_field::<AB::F>();
                let mult = multiplicities[op_idx];

                match opcode {
                    ByteOpcode::AND => {
                        builder.receive_byte(op, table.and, table.b, table.c, shard, channel, mult)
                    }
                    ByteOpcode::OR => {
                        builder.receive_byte(op, table.or, table.b, table.c, shard, channel, mult)
                    }
                    ByteOpcode::XOR => {
                        builder.receive_byte(op, table.xor, table.b, table.c, shard, channel, mult)
                    }
                    ByteOpcode::SLL => {
                        builder.receive_byte(op, table.sll, table.b, table.c, shard, channel, mult)
                    }
                    // Range check on the operand pair only; there is no result column.
                    ByteOpcode::U8Range => {
                        builder.receive_byte(op, zero, table.b, table.c, shard, channel, mult)
                    }
                    // Shift-right yields both a shifted byte and a carry.
                    ByteOpcode::ShrCarry => builder.receive_byte_pair(
                        op,
                        table.shr,
                        table.shr_carry,
                        table.b,
                        table.c,
                        shard,
                        channel,
                        mult,
                    ),
                    ByteOpcode::LTU => {
                        builder.receive_byte(op, table.ltu, table.b, table.c, shard, channel, mult)
                    }
                    // Single-operand lookup: the second operand slot is zero.
                    ByteOpcode::MSB => {
                        builder.receive_byte(op, table.msb, table.b, zero, shard, channel, mult)
                    }
                    // Only the 16-bit value column is used; both operand slots are zero.
                    ByteOpcode::U16Range => builder.receive_byte(
                        op,
                        table.value_u16,
                        zero,
                        zero,
                        shard,
                        channel,
                        mult,
                    ),
                }
            }
        }
    }
}