Skip to main content

sp1_hypercube/ir/
func.rs

1use std::{collections::HashMap, fmt::Display};
2
3use serde::{Deserialize, Serialize};
4use slop_algebra::{ExtensionField, Field};
5
6use crate::ir::{Ast, ExprExtRef, ExprRef, Shape};
7
8/// Whether a parameter to a function is an input to a deterministic output, or if that parameter
9/// itself should be considered a deterministic output.
10///
11/// This is used only for the Picus determinisim checker, hence its name.
12#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
13pub enum PicusArg {
14    /// Input to deterministic outputs.
15    Input,
16    /// A determinstic output.
17    Output,
18    /// Doesn't influence the result. `builder` falls into this category.
19    #[default]
20    Unknown,
21}
22
23/// Attributes of function input parameters
24#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
25pub struct Attribute {
26    /// Whether the parameter is a deterministic output or an input to deterministic outputs. Used
27    /// only for the Picus determinism checker, hence its name.
28    pub picus: PicusArg,
29}
30
31impl Display for Attribute {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self.picus {
34            PicusArg::Input => write!(f, "#[picus(input)]"),
35            PicusArg::Output => write!(f, "#[picus(output)]"),
36            PicusArg::Unknown => Ok(()),
37        }
38    }
39}
40
41/// Represents the "shape" of a function. It only contains the name, input shape, and output shape
42/// of the function, disregarding what the function actually constraints/computes.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct FuncDecl<Expr, ExprExt> {
45    /// The name of the function call, which is usually the operation name.
46    pub name: String,
47    /// The names and the shapes of the input arguments.
48    pub input: Vec<(String, Attribute, Shape<Expr, ExprExt>)>,
49    /// The shape of the output.
50    pub output: Shape<Expr, ExprExt>,
51}
52
53impl<Expr, ExprExt> FuncDecl<Expr, ExprExt> {
54    /// Crates a new [`FuncDecl`].
55    pub fn new(
56        name: String,
57        input: Vec<(String, Attribute, Shape<Expr, ExprExt>)>,
58        output: Shape<Expr, ExprExt>,
59    ) -> Self {
60        Self { name, input, output }
61    }
62}
63
64impl<F: Field, EF: ExtensionField<F>> FuncDecl<ExprRef<F>, ExprExtRef<EF>> {
65    /// A flattened list of the struct representing the position of Input(x) index.
66    pub fn input_mapping(&self) -> HashMap<usize, String> {
67        let mut mapping = HashMap::new();
68        for (name, _, arg) in &self.input {
69            arg.map_input(name.clone(), &mut mapping);
70        }
71        mapping
72    }
73
74    /// The function output's corresponding Lean type in sp1-lean.
75    pub fn to_output_lean_type(&self) -> String {
76        match self.output {
77            Shape::Unit => "SP1ConstraintList".to_string(),
78            _ => format!("{} × SP1ConstraintList", self.output.to_lean_type()),
79        }
80    }
81}
82
83/// Represents a function, containing its name, input/output shapes, and its body [Ast].
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct Func<Expr, ExprExt> {
86    /// The shape of the function. See [`FuncDecl`] for more details.
87    pub decl: FuncDecl<Expr, ExprExt>,
88    /// The body of the [Func], representing the computations performed and the constraints
89    /// asserted by this function.
90    pub body: Ast<Expr, ExprExt>,
91}
92
93impl<F: Field, EF: ExtensionField<F>> Display for Func<ExprRef<F>, ExprExtRef<EF>> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        writeln!(f, "fn {}(", self.decl.name)?;
96        for (i, (name, attr, inp)) in self.decl.input.iter().enumerate() {
97            // Print attribute if it's not Unknown
98            match attr.picus {
99                PicusArg::Unknown => write!(f, "    {name}: {inp:?}")?,
100                _ => write!(f, "    {attr} {name}: {inp:?}")?,
101            }
102            if i < self.decl.input.len() - 1 {
103                writeln!(f, ",")?;
104            }
105        }
106        write!(f, ")")?;
107        match self.decl.output {
108            Shape::Unit => {}
109            _ => write!(f, " -> {:?}", self.decl.output)?,
110        }
111        writeln!(f, " {{")?;
112        write!(f, "{}", self.body.to_string_pretty("   "))?;
113        writeln!(f, "}}")
114    }
115}