Skip to main content

sp1_hypercube/ir/
op.rs

1use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4use slop_algebra::{ExtensionField, Field};
5
6use crate::{
7    air::{AirInteraction, InteractionScope},
8    ir::{FuncDecl, Shape},
9};
10
11/// A binary operation used in the constraint compiler.
12#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum BinOp {
14    /// Addition
15    Add,
16    /// Subtraction
17    Sub,
18    /// Multiply
19    Mul,
20}
21
/// An operation in the IR.
///
/// Operations can appear in the AST, and are used to represent the program.
///
/// `Expr` is the base-field expression type and `ExprExt` the extension-field
/// expression type. For the value-producing variants (`BinOp*`, `Neg*`,
/// `ExtFromBase`, `Assign`), the first expression operand is the destination:
/// the `Display` impl renders them as `dst = ...`.
///
/// NOTE(review): variant order is relied on by the derived serde impls (variant
/// index); append new variants rather than reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OpExpr<Expr, ExprExt> {
    /// An assertion that a base-field expression is zero, rendered `Assert(x == 0)`.
    AssertZero(Expr),
    /// A send operation: an interaction together with its scope.
    Send(AirInteraction<Expr>, InteractionScope),
    /// A receive operation: an interaction together with its scope.
    Receive(AirInteraction<Expr>, InteractionScope),
    /// A function call, described by its declaration (name, inputs, output shape).
    Call(FuncDecl<Expr, ExprExt>),
    /// A base-field binary operation `(op, dst, lhs, rhs)`, rendered `dst = lhs <op> rhs`.
    BinOp(BinOp, Expr, Expr, Expr),
    /// An extension-field binary operation `(op, dst, lhs, rhs)`, rendered
    /// `dst = lhs <op> rhs`.
    BinOpExt(BinOp, ExprExt, ExprExt, ExprExt),
    /// A mixed binary operation `(op, dst, lhs, rhs)` whose destination and left
    /// operand are extension-field expressions and whose right operand is a
    /// base-field expression; rendered `dst = lhs <op> rhs`.
    BinOpBaseExt(BinOp, ExprExt, ExprExt, Expr),
    /// A base-field negation `(dst, src)`, rendered `dst = -src`.
    Neg(Expr, Expr),
    /// An extension-field negation `(dst, src)`, rendered `dst = -src`.
    NegExt(ExprExt, ExprExt),
    /// A conversion `(dst, src)` from a base-field expression into an
    /// extension-field destination, rendered `dst = src`.
    ExtFromBase(ExprExt, Expr),
    /// An assertion that an extension-field expression is zero, rendered `Assert(x == 0)`.
    AssertExtZero(ExprExt),
    /// A base-field assignment `(dst, src)`, rendered `dst = src`.
    Assign(Expr, Expr),
}
52
53impl<F, EF> OpExpr<crate::ir::ExprRef<F>, crate::ir::ExprExtRef<EF>>
54where
55    F: Field,
56    EF: ExtensionField<F>,
57{
58    fn write_interaction<Expr>(
59        f: &mut std::fmt::Formatter<'_>,
60        interaction: &AirInteraction<Expr>,
61        scope: InteractionScope,
62    ) -> std::fmt::Result
63    where
64        Expr: Display,
65    {
66        write!(
67            f,
68            "kind: {}, scope: {scope}, multiplicity: {}, values: [",
69            interaction.kind, interaction.multiplicity
70        )?;
71        for (i, value) in interaction.values.iter().enumerate() {
72            write!(f, "{value}")?;
73            if i < interaction.values.len() - 1 {
74                write!(f, ", ")?;
75            }
76        }
77        write!(f, "]")?;
78        Ok(())
79    }
80}
81
82impl<F, EF> Display for OpExpr<crate::ir::ExprRef<F>, crate::ir::ExprExtRef<EF>>
83where
84    F: Field,
85    EF: ExtensionField<F>,
86{
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        match self {
89            OpExpr::AssertZero(x) => write!(f, "Assert({x} == 0)"),
90            OpExpr::Send(interaction, scope) => {
91                write!(f, "Send(")?;
92                Self::write_interaction(f, interaction, *scope)?;
93                write!(f, ")")?;
94                Ok(())
95            }
96            OpExpr::Receive(interaction, scope) => {
97                write!(f, "Receive(")?;
98                Self::write_interaction(f, interaction, *scope)?;
99                write!(f, ")")?;
100                Ok(())
101            }
102            OpExpr::Assign(a, b) => write!(f, "{a} = {b}"),
103            OpExpr::Call(func) => {
104                match func.output {
105                    Shape::Unit => {}
106                    _ => write!(f, "{:?} = ", func.output)?,
107                }
108                write!(f, "{}(", func.name)?;
109                for (i, inp) in func.input.iter().enumerate() {
110                    write!(f, "{inp:?}")?;
111                    if i < func.input.len() - 1 {
112                        write!(f, ", ")?;
113                    }
114                }
115                write!(f, ")")?;
116                Ok(())
117            }
118            OpExpr::BinOp(op, a, b, c) => match op {
119                BinOp::Add => write!(f, "{a} = {b} + {c}"),
120                BinOp::Sub => write!(f, "{a} = {b} - {c}"),
121                BinOp::Mul => write!(f, "{a} = {b} * {c}"),
122            },
123            OpExpr::BinOpExt(op, a, b, c) => match op {
124                BinOp::Add => write!(f, "{a} = {b} + {c}"),
125                BinOp::Sub => write!(f, "{a} = {b} - {c}"),
126                BinOp::Mul => write!(f, "{a} = {b} * {c}"),
127            },
128            OpExpr::BinOpBaseExt(op, a, b, c) => match op {
129                BinOp::Add => write!(f, "{a} = {b} + {c}"),
130                BinOp::Sub => write!(f, "{a} = {b} - {c}"),
131                BinOp::Mul => write!(f, "{a} = {b} * {c}"),
132            },
133            OpExpr::Neg(a, b) => write!(f, "{a} = -{b}"),
134            OpExpr::NegExt(a, b) => write!(f, "{a} = -{b}"),
135            OpExpr::ExtFromBase(a, b) => write!(f, "{a} = {b}"),
136            OpExpr::AssertExtZero(a) => write!(f, "Assert({a} == 0)"),
137        }
138    }
139}