1use std::{
2 collections::HashMap,
3 sync::{Arc, LazyLock, Mutex},
4};
5
6use serde::{Deserialize, Serialize};
7use slop_algebra::{extension::BinomialExtensionField, ExtensionField, Field};
8
9use crate::{
10 air::{AirInteraction, InteractionScope},
11 ir::{Attribute, BinOp, ExprExtRef, ExprRef, FuncDecl, IrVar, OpExpr, Shape},
12 InteractionKind,
13};
14
15use sp1_primitives::SP1Field;
16type F = SP1Field;
17type EF = BinomialExtensionField<SP1Field, 4>;
18
19type AstType = Ast<ExprRef<F>, ExprExtRef<EF>>;
20
21pub static GLOBAL_AST: LazyLock<Arc<Mutex<AstType>>> =
26 LazyLock::new(|| Arc::new(Mutex::new(Ast::new())));
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct Ast<Expr, ExprExt> {
31 assignments: Vec<usize>,
32 ext_assignments: Vec<usize>,
33 operations: Vec<OpExpr<Expr, ExprExt>>,
34}
35
36impl<F: Field, EF: ExtensionField<F>> Ast<ExprRef<F>, ExprExtRef<EF>> {
37 #[must_use]
39 pub fn new() -> Self {
40 Self { assignments: vec![], ext_assignments: vec![], operations: vec![] }
41 }
42
43 pub fn alloc(&mut self) -> ExprRef<F> {
50 let id = self.assignments.len();
51 self.assignments.push(self.operations.len());
52 ExprRef::Expr(id)
53 }
54
55 pub fn alloc_array<const N: usize>(&mut self) -> [ExprRef<F>; N] {
58 core::array::from_fn(|_| self.alloc())
59 }
60
61 pub fn assign(&mut self, a: ExprRef<F>, b: ExprRef<F>) {
63 let op = OpExpr::Assign(a, b);
64 self.operations.push(op);
65 }
66
67 pub fn alloc_ext(&mut self) -> ExprExtRef<EF> {
69 let id = self.ext_assignments.len();
70 self.ext_assignments.push(self.operations.len());
71 ExprExtRef::Expr(id)
72 }
73
74 pub fn assert_zero(&mut self, x: ExprRef<F>) {
76 let op = OpExpr::AssertZero(x);
77 self.operations.push(op);
78 }
79
80 pub fn assert_ext_zero(&mut self, x: ExprExtRef<EF>) {
82 let op = OpExpr::AssertExtZero(x);
83 self.operations.push(op);
84 }
85
86 pub fn bin_op(&mut self, op: BinOp, a: ExprRef<F>, b: ExprRef<F>) -> ExprRef<F> {
89 let result = self.alloc();
90 let op = OpExpr::BinOp(op, result, a, b);
91 self.operations.push(op);
92 result
93 }
94
95 pub fn negate(&mut self, a: ExprRef<F>) -> ExprRef<F> {
97 let result = self.alloc();
98 let op = OpExpr::Neg(result, a);
99 self.operations.push(op);
100 result
101 }
102
103 pub fn bin_op_ext(
105 &mut self,
106 op: BinOp,
107 a: ExprExtRef<EF>,
108 b: ExprExtRef<EF>,
109 ) -> ExprExtRef<EF> {
110 let result = self.alloc_ext();
111 let op = OpExpr::BinOpExt(op, result, a, b);
112 self.operations.push(op);
113 result
114 }
115
116 pub fn bin_op_base_ext(
118 &mut self,
119 op: BinOp,
120 a: ExprExtRef<EF>,
121 b: ExprRef<F>,
122 ) -> ExprExtRef<EF> {
123 let result = self.alloc_ext();
124 let op = OpExpr::BinOpBaseExt(op, result, a, b);
125 self.operations.push(op);
126 result
127 }
128
129 pub fn neg_ext(&mut self, a: ExprExtRef<EF>) -> ExprExtRef<EF> {
131 let result = self.alloc_ext();
132 let op = OpExpr::NegExt(result, a);
133 self.operations.push(op);
134 result
135 }
136
137 pub fn ext_from_base(&mut self, a: ExprRef<F>) -> ExprExtRef<EF> {
139 let result = self.alloc_ext();
140 let op = OpExpr::ExtFromBase(result, a);
141 self.operations.push(op);
142 result
143 }
144
145 pub fn send(&mut self, message: AirInteraction<ExprRef<F>>, scope: InteractionScope) {
147 let op = OpExpr::Send(message, scope);
148 self.operations.push(op);
149 }
150
151 pub fn receive(&mut self, message: AirInteraction<ExprRef<F>>, scope: InteractionScope) {
153 let op = OpExpr::Receive(message, scope);
154 self.operations.push(op);
155 }
156
157 #[must_use]
159 pub fn to_string_pretty(&self, prefix: &str) -> String {
160 let mut s = String::new();
161 for op in &self.operations {
162 s.push_str(&format!("{prefix}{op}\n"));
163 }
164 s
165 }
166
167 pub fn call_operation(
169 &mut self,
170 name: String,
171 inputs: Vec<(String, Attribute, Shape<ExprRef<F>, ExprExtRef<EF>>)>,
172 output: Shape<ExprRef<F>, ExprExtRef<EF>>,
173 ) {
174 let func = FuncDecl::new(name, inputs, output);
175 let op = OpExpr::Call(func);
176 self.operations.push(op);
177 }
178
179 #[must_use]
184 pub fn to_lean_components(
185 &self,
186 mapping: &HashMap<usize, String>,
187 ) -> (Vec<String>, Vec<String>, usize) {
188 let mut steps: Vec<String> = Vec::default();
189 let mut calls: usize = 0;
190 let mut constraints: Vec<String> = Vec::default();
191
192 for opexpr in &self.operations {
193 match opexpr {
194 OpExpr::AssertZero(expr) => {
195 constraints.push(format!("(.assertZero {})", expr.to_lean_string(mapping)));
196 }
197 OpExpr::Neg(a, b) => {
198 steps.push(format!(
199 "let {} : Fin KB := -{}",
200 a.expr_to_lean_string(),
201 b.to_lean_string(mapping),
202 ));
203 }
204 OpExpr::BinOp(op, result, a, b) => {
205 let result_str = result.expr_to_lean_string();
206 let a_str = a.to_lean_string(mapping);
207 let b_str = b.to_lean_string(mapping);
208 match op {
209 BinOp::Add => {
210 steps.push(format!("let {result_str} : Fin KB := {a_str} + {b_str}"));
211 }
212 BinOp::Sub => {
213 steps.push(format!("let {result_str} : Fin KB := {a_str} - {b_str}"));
214 }
215 BinOp::Mul => {
216 steps.push(format!("let {result_str} : Fin KB := {a_str} * {b_str}"));
217 }
218 }
219 }
220 OpExpr::Send(interaction, _) => match interaction.kind {
221 InteractionKind::Byte
222 | InteractionKind::State
223 | InteractionKind::Memory
224 | InteractionKind::Program => {
225 constraints.push(format!(
226 "(.send {} {})",
227 interaction.to_lean_string(mapping),
228 interaction.multiplicity.to_lean_string(mapping)
229 ));
230 }
231 _ => {}
232 },
233 OpExpr::Receive(interaction, _) => match interaction.kind {
234 InteractionKind::Byte
235 | InteractionKind::State
236 | InteractionKind::Memory
237 | InteractionKind::Program => {
238 constraints.push(format!(
239 "(.receive {} {})",
240 interaction.to_lean_string(mapping),
241 interaction.multiplicity.to_lean_string(mapping),
242 ));
243 }
244 _ => {}
245 },
246 OpExpr::Call(decl) => {
247 let mut step = String::new();
248 match decl.output {
249 Shape::Unit => {
250 step.push_str(&format!("let CS{calls} : SP1ConstraintList := "));
251 }
252 _ => {
253 step.push_str(&format!(
254 "let ⟨{}, CS{}⟩ := ",
255 decl.output.to_lean_destructor(),
256 calls,
257 ));
258 }
259 }
260
261 step.push_str(&format!("{}.constraints", decl.name));
262
263 for input in &decl.input {
264 step.push(' ');
265 step.push_str(&input.2.to_lean_constructor(mapping));
266 }
267
268 calls += 1;
269 steps.push(step);
270 }
271 OpExpr::Assign(ExprRef::IrVar(IrVar::OutputArg(_)), _) => {
272 }
274 _ => todo!(),
275 }
276 }
277
278 (steps, constraints, calls)
279 }
280}
281
282impl<F: Field, EF: ExtensionField<F>> Default for Ast<ExprRef<F>, ExprExtRef<EF>> {
283 fn default() -> Self {
284 Self::new()
285 }
286}