1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
use std::{cell::RefCell, collections::HashMap};

use triton_opcodes::instruction::LabelledInstruction;
use triton_vm::BFieldElement;

use crate::snippet::{DataType, Snippet};

/// Boxed closure that mirrors, in Rust, the effect a raw-code inner function has on the
/// stack. Named so the field type below stays readable without silencing
/// `clippy::type_complexity`.
type RustShadowingFn = Box<RefCell<dyn FnMut(&mut Vec<BFieldElement>)>>;

/// A data structure for describing an inner function predicate to filter with,
/// or to map with.
pub struct RawCode {
    /// The TASM code of the function. The constructors require a label as the first
    /// line and `return` or `recurse` as the last line.
    pub function: Vec<LabelledInstruction>,
    /// Types the function consumes. When used as an `InnerFunction`, the last entry is
    /// the type of the list element; preceding entries are additional inputs.
    pub input_types: Vec<DataType>,
    /// Types the function produces.
    pub output_types: Vec<DataType>,
    /// Optional Rust mirror of `function`, used for equivalence testing.
    rust_shadowing: Option<RustShadowingFn>,
}

impl RawCode {
    /// Create an inner function from raw TASM code, without a Rust shadowing.
    ///
    /// Such a function cannot be used for equivalence testing:
    /// [`InnerFunction::rust_shadowing`] panics on raw code that has no shadowing.
    ///
    /// # Panics
    ///
    /// Panics if `function` has fewer than two lines, if its first line is not a
    /// label, or if its last line is neither `return` nor `recurse`.
    pub fn new(
        function: Vec<LabelledInstruction>,
        input_types: Vec<DataType>,
        output_types: Vec<DataType>,
    ) -> Self {
        Self::assert_well_formed(&function);

        Self {
            function,
            input_types,
            output_types,
            rust_shadowing: None,
        }
    }

    /// Create an inner function from raw TASM code together with a Rust closure that
    /// mirrors its effect on the stack, for equivalence testing.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`RawCode::new`].
    // The closure type is spelled out in full so it matches the `rust_shadowing` field.
    #[allow(clippy::type_complexity)]
    pub fn new_with_shadowing(
        function: Vec<LabelledInstruction>,
        input_types: Vec<DataType>,
        output_types: Vec<DataType>,
        rust_shadowing: Box<RefCell<dyn FnMut(&mut Vec<BFieldElement>)>>,
    ) -> Self {
        let mut raw_code = Self::new(function, input_types, output_types);
        raw_code.rust_shadowing = Some(rust_shadowing);
        raw_code
    }

    /// Panic unless `function` is shaped like a callable inner function: at least two
    /// lines, a label first, and `return` or `recurse` last.
    fn assert_well_formed(function: &[LabelledInstruction]) {
        assert!(
            function.len() >= 2,
            "Inner function must have at least two lines: a label and a return or recurse"
        );
        assert!(
            matches!(function[0], LabelledInstruction::Label(_)),
            "First line of inner function must be label. Got: {}",
            function[0]
        );

        // Cannot fail: the length was checked above.
        let last_line = function.last().unwrap();
        assert!(
            matches!(
                last_line,
                LabelledInstruction::Instruction(
                    triton_opcodes::instruction::AnInstruction::Return
                ) | LabelledInstruction::Instruction(
                    triton_opcodes::instruction::AnInstruction::Recurse
                )
            ),
            "Last line of inner function must be either return or recurse. Got: {}",
            last_line
        );
    }
}

impl RawCode {
    /// Label heading the inner function's code; callers use it as the call target.
    ///
    /// # Panics
    ///
    /// Panics if `function` is empty or if its first line is an instruction rather
    /// than a label.
    pub fn entrypoint(&self) -> String {
        let first_line = &self.function[0];
        match first_line {
            LabelledInstruction::Label(label_name) => label_name.to_owned(),
            LabelledInstruction::Instruction(inst) => {
                panic!("First line of inner function must be a label. Got: {inst}")
            }
        }
    }
}

/// An inner function called by a higher-order function, e.g. a predicate to filter
/// with or a function to map with.
pub enum InnerFunction {
    /// Hand-written TASM code, optionally paired with a Rust shadowing closure.
    RawCode(RawCode),
    /// A snippet whose types, entrypoint and Rust shadowing come from the snippet itself.
    Snippet(Box<dyn Snippet>),

    /// Used when a snippet is declared somewhere else, and it's not the responsibility of
    /// the higher order function to import it. Only its label and types are known, so it
    /// cannot be Rust-shadowed.
    NoFunctionBody(NoFunctionBody),
}

/// Signature of an inner function whose body is defined elsewhere: the label to call
/// and its input/output types.
pub struct NoFunctionBody {
    /// Label of the function; returned as its entrypoint.
    pub label_name: String,
    /// Types the function consumes; the last entry is the list element type.
    pub input_types: Vec<DataType>,
    /// Types the function produces.
    pub output_types: Vec<DataType>,
}

impl InnerFunction {
    /// Return the input types this inner function accepts.
    ///
    /// The last entry is the type of the list element the function is applied to; any
    /// preceding entries are additional inputs.
    pub fn get_input_types(&self) -> Vec<DataType> {
        match self {
            InnerFunction::RawCode(raw) => raw.input_types.clone(),
            InnerFunction::Snippet(snippet) => snippet.input_types(),
            InnerFunction::NoFunctionBody(no_body) => no_body.input_types.clone(),
        }
    }

    /// Split the input types into `(additional inputs, list element type)`, where the
    /// list element type is the last declared input type.
    ///
    /// # Panics
    ///
    /// Panics if the inner function declares no input types at all.
    fn split_input_types(&self) -> (Vec<DataType>, DataType) {
        let mut input_types = self.get_input_types();
        let element_type = input_types
            .pop()
            .expect("Inner function must take at least one input: the list element");
        (input_types, element_type)
    }

    /// Return the expected type of list element this function accepts.
    ///
    /// # Panics
    ///
    /// Panics if the inner function declares no input types.
    pub fn input_list_element_type(&self) -> DataType {
        let (_, element_type) = self.split_input_types();
        element_type
    }

    /// Return all input types apart from the element type of the input list.
    /// May be the empty list.
    ///
    /// # Panics
    ///
    /// Panics if the inner function declares no input types.
    pub fn additional_inputs(&self) -> Vec<DataType> {
        let (additional_inputs, _) = self.split_input_types();
        additional_inputs
    }

    /// Return the size in words for the additional elements, all elements
    /// apart from the element from the input list.
    ///
    /// # Panics
    ///
    /// Panics if the inner function declares no input types.
    pub fn size_of_additional_inputs(&self) -> usize {
        self.additional_inputs().iter().map(|x| x.get_size()).sum()
    }

    /// Return types this function outputs.
    pub fn get_output_types(&self) -> Vec<DataType> {
        match self {
            InnerFunction::RawCode(raw) => raw.output_types.clone(),
            InnerFunction::Snippet(snippet) => snippet.output_types(),
            InnerFunction::NoFunctionBody(no_body) => no_body.output_types.clone(),
        }
    }

    /// Return the entrypoint, label, of the inner function. Used to make a call to this function.
    pub fn entrypoint(&self) -> String {
        match self {
            InnerFunction::RawCode(raw) => raw.entrypoint(),
            InnerFunction::Snippet(snippet) => snippet.entrypoint(),
            InnerFunction::NoFunctionBody(no_body) => no_body.label_name.clone(),
        }
    }

    /// For testing purposes, this function can mirror what the TASM code does.
    ///
    /// Raw code runs its shadowing closure on `stack` only; snippets receive the
    /// standard input, secret input, stack and memory.
    ///
    /// # Panics
    ///
    /// Panics for raw code without a Rust shadowing and for a function without a body.
    /// Also panics if a raw-code shadowing closure is already borrowed, since it is
    /// accessed through `RefCell::borrow_mut`.
    pub fn rust_shadowing(
        &self,
        std_in: &[BFieldElement],
        secret_in: &[BFieldElement],
        stack: &mut Vec<BFieldElement>,
        memory: &mut HashMap<BFieldElement, BFieldElement>,
    ) {
        match self {
            InnerFunction::RawCode(raw) => {
                if let Some(shadowing) = &raw.rust_shadowing {
                    let mut shadowing = shadowing.borrow_mut();
                    (*shadowing)(stack)
                } else {
                    panic!("Raw code must have rust shadowing for equivalence testing")
                }
            }
            InnerFunction::Snippet(snippet) => {
                snippet.rust_shadowing(stack, std_in.to_vec(), secret_in.to_vec(), memory)
            }
            InnerFunction::NoFunctionBody(_) => {
                panic!("Cannot rust shadow inner function without function body")
            }
        }
    }
}