Skip to main content

sp1_hypercube/ir/compiler.rs

1use std::collections::BTreeMap;
2
3use crate::{
4    air::{AirInteraction, InteractionScope, MachineAir, MessageBuilder},
5    ir::{Ast, Attribute, ExprExtRef, ExprRef, Func, Shape, GLOBAL_AST},
6};
7use slop_air::{AirBuilder, AirBuilderWithPublicValues, ExtensionBuilder, PairBuilder};
8use slop_matrix::dense::RowMajorMatrix;
9
10use crate::ir::expr_impl::{Expr, ExprExt, EF, F};
11
12/// The constraint compiler that records the constraints of a chip.
13#[derive(Clone, Debug)]
14pub struct ConstraintCompiler {
15    public_values: Vec<Expr>,
16    preprocessed: RowMajorMatrix<Expr>,
17    main: RowMajorMatrix<Expr>,
18    modules: BTreeMap<String, Func<Expr, ExprExt>>,
19    parent: Option<Ast<ExprRef<F>, ExprExtRef<EF>>>,
20}
21
22impl ConstraintCompiler {
23    /// Creates a new [`ConstraintCompiler`]
24    pub fn new<A: MachineAir<F>>(air: &A, num_public_values: usize) -> Self {
25        let preprocessed_width = air.preprocessed_width();
26        let main_width = air.width();
27        Self::with_sizes(num_public_values, preprocessed_width, main_width)
28    }
29
30    /// Creates a new [`ConstraintCompiler`] with specific dimensions.
31    pub fn with_sizes(
32        num_public_values: usize,
33        preprocessed_width: usize,
34        main_width: usize,
35    ) -> Self {
36        // Initialize the global AST to empty.
37        let mut ast = GLOBAL_AST.lock().unwrap();
38        *ast = Ast::new();
39
40        // Initialize the public values.
41        let public_values = (0..num_public_values).map(Expr::public).collect();
42        // Initialize the preprocessed and main traces.
43        let preprocessed = (0..preprocessed_width).map(Expr::preprocessed).collect();
44        let preprocessed = RowMajorMatrix::new(preprocessed, preprocessed_width);
45        let main = (0..main_width).map(Expr::main).collect();
46        let main = RowMajorMatrix::new(main, main_width);
47
48        Self { public_values, preprocessed, main, modules: BTreeMap::new(), parent: None }
49    }
50
51    /// Returns the currently recorded AST.
52    pub fn ast(&self) -> Ast<ExprRef<F>, ExprExtRef<EF>> {
53        let ast = GLOBAL_AST.lock().unwrap();
54        ast.clone()
55    }
56
57    fn region(&self) -> Self {
58        let parent = self.ast();
59        let mut ast = GLOBAL_AST.lock().unwrap();
60        *ast = Ast::new();
61        Self {
62            public_values: self.public_values.clone(),
63            preprocessed: self.preprocessed.clone(),
64            main: self.main.clone(),
65            modules: BTreeMap::new(),
66            parent: Some(parent),
67        }
68    }
69
70    /// Records a module (that is usually just a function call that represents an operation).
71    pub fn register_module(
72        &mut self,
73        name: String,
74        params: Vec<(String, Attribute, Shape<ExprRef<F>, ExprExtRef<EF>>)>,
75        body: impl FnOnce(&mut Self) -> Shape<ExprRef<F>, ExprExtRef<EF>>,
76    ) {
77        let mut body_builder = self.region();
78        let result = body(&mut body_builder);
79        let body = body_builder.ast();
80
81        let decl = crate::ir::FuncDecl::new(name.clone(), params, result);
82        self.modules.append(&mut body_builder.modules);
83        self.modules.insert(name, Func { decl, body });
84    }
85
86    /// The modules that has been recorded.
87    #[must_use]
88    pub fn modules(&self) -> &BTreeMap<String, Func<Expr, ExprExt>> {
89        &self.modules
90    }
91
92    /// The total number of cols of the chip.
93    #[must_use]
94    pub fn num_cols(&self) -> usize {
95        self.main.width
96    }
97}
98
99impl Drop for ConstraintCompiler {
100    fn drop(&mut self) {
101        if let Some(parent) = self.parent.take() {
102            let mut ast = GLOBAL_AST.lock().unwrap();
103            *ast = parent;
104        }
105    }
106}
107
108impl AirBuilder for ConstraintCompiler {
109    type F = F;
110    type Expr = Expr;
111    type Var = Expr;
112    type M = RowMajorMatrix<Expr>;
113
114    fn main(&self) -> Self::M {
115        self.main.clone()
116    }
117
118    fn is_first_row(&self) -> Self::Expr {
119        unreachable!("first row is not supported")
120    }
121
122    fn is_last_row(&self) -> Self::Expr {
123        unreachable!("last row is not supported")
124    }
125
126    fn is_transition_window(&self, _size: usize) -> Self::Expr {
127        unreachable!("transition window is not supported")
128    }
129
130    fn assert_zero<I: Into<Self::Expr>>(&mut self, x: I) {
131        let x = x.into();
132        let mut ast = GLOBAL_AST.lock().unwrap();
133        ast.assert_zero(x);
134    }
135}
136
137impl MessageBuilder<AirInteraction<Expr>> for ConstraintCompiler {
138    fn send(&mut self, message: AirInteraction<Expr>, scope: InteractionScope) {
139        let mut ast = GLOBAL_AST.lock().unwrap();
140        ast.send(message, scope);
141    }
142
143    fn receive(&mut self, message: AirInteraction<Expr>, scope: InteractionScope) {
144        let mut ast = GLOBAL_AST.lock().unwrap();
145        ast.receive(message, scope);
146    }
147}
148
149impl PairBuilder for ConstraintCompiler {
150    fn preprocessed(&self) -> Self::M {
151        self.preprocessed.clone()
152    }
153}
154
155impl AirBuilderWithPublicValues for ConstraintCompiler {
156    type PublicVar = Expr;
157
158    fn public_values(&self) -> &[Self::PublicVar] {
159        &self.public_values
160    }
161}
162
163impl ExtensionBuilder for ConstraintCompiler {
164    type EF = EF;
165    type ExprEF = ExprExt;
166    type VarEF = ExprExt;
167
168    fn assert_zero_ext<I>(&mut self, x: I)
169    where
170        I: Into<Self::ExprEF>,
171    {
172        let x = x.into();
173        let mut ast = GLOBAL_AST.lock().unwrap();
174        ast.assert_ext_zero(x);
175    }
176}