tasm_lib/
library.rs

1use std::collections::HashMap;
2
3use arbitrary::Arbitrary;
4use itertools::Itertools;
5use num_traits::ConstOne;
6use triton_vm::memory_layout::MemoryRegion;
7use triton_vm::prelude::*;
8
9use crate::prelude::*;
10
11/// By [convention](crate::memory), the last full memory page is reserved for the static allocator.
12/// For convenience during [debugging],[^1] the static allocator starts at the last address of that
13/// page, and grows downwards.
14///
15/// [^1]: and partly for historic reasons
16///
17/// [debugging]: crate::maybe_write_debuggable_vm_state_to_disk
18const STATIC_MEMORY_FIRST_ADDRESS_AS_U64: u64 = BFieldElement::MAX - 1;
19pub const STATIC_MEMORY_FIRST_ADDRESS: BFieldElement =
20    BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64);
21pub const STATIC_MEMORY_LAST_ADDRESS: BFieldElement =
22    BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64 - u32::MAX as u64);
23
/// Represents a set of imports for a single Program or Snippet, and moreover tracks some data used
/// for initializing the [memory allocator](crate::memory).
#[derive(Clone, Debug)]
pub struct Library {
    /// Imported dependencies: maps each snippet's entrypoint label (or the name given to
    /// [`Library::explicit_import`]) to that snippet's code.
    seen_snippets: HashMap<String, Vec<LabelledInstruction>>,

    /// The number of statically allocated words, i.e., the running total of words handed out by
    /// [`Library::kmalloc`] (plus any words marked as preallocated).
    num_allocated_words: u32,
}
34
/// Represents a [static memory allocation][kmalloc] within Triton VM.
/// Both its location within Triton VM's memory and its size are fixed.
///
/// The allocation covers `num_words` consecutive addresses, starting at the
/// [write address](Self::write_address) and ending at the [read address](Self::read_address).
///
/// [kmalloc]: Library::kmalloc
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct StaticAllocation {
    /// The address of the allocation's first (lowest) word.
    write_address: BFieldElement,
    /// The number of words in the allocation.
    num_words: u32,
}
44
45impl StaticAllocation {
46    /// The address from which the allocated memory can be read.
47    pub fn read_address(&self) -> BFieldElement {
48        let offset = bfe!(self.num_words) - BFieldElement::ONE;
49        self.write_address() + offset
50    }
51
52    /// The address to which the allocated memory can be written.
53    pub fn write_address(&self) -> BFieldElement {
54        self.write_address
55    }
56
57    /// The number of words allocated in this memory block.
58    pub fn num_words(&self) -> u32 {
59        self.num_words
60    }
61}
62
63impl Default for Library {
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl Library {
70    pub fn kmalloc_memory_region() -> MemoryRegion {
71        MemoryRegion::new(STATIC_MEMORY_LAST_ADDRESS, 1usize << 32)
72    }
73
74    pub fn new() -> Self {
75        Self {
76            seen_snippets: HashMap::default(),
77            num_allocated_words: 0,
78        }
79    }
80
81    /// Create an empty library.
82    pub fn empty() -> Self {
83        Self::new()
84    }
85
86    #[cfg(test)]
87    pub fn with_preallocated_memory(words_statically_allocated: u32) -> Self {
88        Library {
89            num_allocated_words: words_statically_allocated,
90            ..Self::new()
91        }
92    }
93
94    /// Import `T: Snippet`.
95    ///
96    /// Recursively imports `T`'s dependencies.
97    /// Does not import the snippets with the same entrypoint twice.
98    ///
99    /// Avoid cyclic dependencies by only calling `T::function_code()` which
100    /// may call `.import()` if `.import::<T>()` wasn't already called once.
101    // todo: Above comment is not overly clear. Improve it.
102    pub fn import(&mut self, snippet: Box<dyn BasicSnippet>) -> String {
103        let dep_entrypoint = snippet.entrypoint();
104
105        let is_new_dependency = !self.seen_snippets.contains_key(&dep_entrypoint);
106        if is_new_dependency {
107            let dep_body = snippet.annotated_code(self);
108            self.seen_snippets.insert(dep_entrypoint.clone(), dep_body);
109        }
110
111        dep_entrypoint
112    }
113
114    /// Import code that does not implement the `Snippet` trait
115    ///
116    /// If possible, you should use the [`import`](Self::import) method as
117    /// it gives better protections and allows you to test functions in
118    /// isolation. This method is intended to add function to the assembly
119    /// that you have defined inline and where a function call is needed due to
120    /// e.g. a dynamic counter.
121    pub fn explicit_import(&mut self, name: &str, body: &[LabelledInstruction]) -> String {
122        if !self.seen_snippets.contains_key(name) {
123            self.seen_snippets.insert(name.to_owned(), body.to_vec());
124        }
125
126        name.to_string()
127    }
128
129    /// Return a list of all external dependencies sorted by name. All snippets are sorted
130    /// alphabetically to ensure that generated programs are deterministic.
131    pub fn all_external_dependencies(&self) -> Vec<Vec<LabelledInstruction>> {
132        self.seen_snippets
133            .iter()
134            .sorted_by_key(|(k, _)| *k)
135            .map(|(_, code)| code.clone())
136            .collect()
137    }
138
139    /// Return the name of all imported snippets, sorted alphabetically to ensure that output is
140    /// deterministic.
141    pub fn get_all_snippet_names(&self) -> Vec<String> {
142        let mut ret = self.seen_snippets.keys().cloned().collect_vec();
143        ret.sort_unstable();
144        ret
145    }
146
147    /// Return a list of instructions containing all imported snippets.
148    pub fn all_imports(&self) -> Vec<LabelledInstruction> {
149        self.all_external_dependencies().concat()
150    }
151
152    /// Statically allocate `num_words` words of memory.
153    ///
154    /// # Panics
155    ///
156    /// Panics if
157    /// - `num_words` is zero,
158    /// - the total number of statically allocated words exceeds `u32::MAX`.
159    pub fn kmalloc(&mut self, num_words: u32) -> StaticAllocation {
160        assert!(num_words > 0, "must allocate a positive number of words");
161        let write_address =
162            STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.num_allocated_words) - bfe!(num_words - 1);
163        self.num_allocated_words = self
164            .num_allocated_words
165            .checked_add(num_words)
166            .expect("Cannot allocate more that u32::MAX words through `kmalloc`.");
167
168        StaticAllocation {
169            write_address,
170            num_words,
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use triton_vm::prelude::Program;
    use triton_vm::prelude::triton_asm;

    use super::*;
    use crate::mmr::calculate_new_peaks_from_leaf_mutation::MmrCalculateNewPeaksFromLeafMutationMtIndices;
    use crate::test_prelude::*;

    #[derive(Debug, Copy, Clone, BFieldCodec)]
    struct ZeroSizedType;

    /// Depends on [`DummyTestSnippetB`] and [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetA;

    /// Depends on [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetB;

    /// Has no dependencies.
    #[derive(Debug)]
    struct DummyTestSnippetC;

    impl BasicSnippet for DummyTestSnippetA {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "dummy".to_string())]
        }

        fn entrypoint(&self) -> String {
            "tasmlib_a_dummy_test_value".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let b = library.import(Box::new(DummyTestSnippetB));
            let c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {self.entrypoint()}:
                    call {b}
                    call {c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetB {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            ["1"; 2]
                .map(|name| (DataType::Bfe, name.to_string()))
                .to_vec()
        }

        fn entrypoint(&self) -> String {
            "tasmlib_b_dummy_test_value".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {self.entrypoint()}:
                    call {c}
                    call {c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetC {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bfe, "1".to_string())]
        }

        fn entrypoint(&self) -> String {
            "tasmlib_c_dummy_test_value".to_string()
        }

        fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
            triton_asm!({self.entrypoint()}: push 1 return)
        }
    }

    impl Closure for DummyTestSnippetA {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            push_encodable(stack, &xfe![[1, 1, 1]]);
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetB {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            stack.push(bfe!(1));
            stack.push(bfe!(1));
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetC {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            stack.push(bfe!(1));
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    #[test]
    fn library_includes() {
        ShadowedClosure::new(DummyTestSnippetA).test();
        ShadowedClosure::new(DummyTestSnippetB).test();
        ShadowedClosure::new(DummyTestSnippetC).test();
    }

    #[test]
    fn get_all_snippet_names_test_a() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetA));
        assert_eq!(
            vec![
                "tasmlib_a_dummy_test_value",
                "tasmlib_b_dummy_test_value",
                "tasmlib_c_dummy_test_value",
            ],
            lib.get_all_snippet_names()
        );
    }

    #[test]
    fn get_all_snippet_names_test_b() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetB));
        assert_eq!(
            vec!["tasmlib_b_dummy_test_value", "tasmlib_c_dummy_test_value"],
            lib.get_all_snippet_names()
        );
    }

    #[test]
    fn import_returns_entrypoint() {
        let mut lib = Library::new();
        assert_eq!(
            "tasmlib_a_dummy_test_value",
            lib.import(Box::new(DummyTestSnippetA))
        );
        assert_eq!(
            "tasmlib_c_dummy_test_value",
            lib.import(Box::new(DummyTestSnippetC))
        );
    }

    #[test]
    fn all_imports_as_instruction_lists() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetA));
        lib.import(Box::new(DummyTestSnippetA));
        lib.import(Box::new(DummyTestSnippetC));

        // Re-importing `A`, and importing `C` (already a transitive dependency of `A`), must not
        // add anything new.
        assert_eq!(
            vec![
                "tasmlib_a_dummy_test_value",
                "tasmlib_b_dummy_test_value",
                "tasmlib_c_dummy_test_value",
            ],
            lib.get_all_snippet_names()
        );

        let all_imports = lib.all_imports();
        let all_ext_deps = lib.all_external_dependencies();
        assert_eq!(3, all_ext_deps.len());
        assert!(all_ext_deps.iter().all(|code| !code.is_empty()));
        assert_eq!(all_ext_deps.concat(), all_imports);
    }

    #[test]
    fn explicit_import_keeps_first_body() {
        let mut lib = Library::new();
        let first_body = triton_asm!(foo: push 1 return);
        let second_body = triton_asm!(foo: push 2 return);

        assert_eq!("foo", lib.explicit_import("foo", &first_body));
        assert_eq!("foo", lib.explicit_import("foo", &second_body));

        assert_eq!(vec!["foo"], lib.get_all_snippet_names());
        assert_eq!(first_body, lib.all_imports());
    }

    #[test]
    fn program_is_deterministic() {
        // Ensure that a generated program is deterministic, by checking that the imports
        // are always sorted the same way.
        fn smaller_program() -> Program {
            let mut library = Library::new();
            let memcpy = library.import(Box::new(MemCpy));
            let calculate_new_peaks_from_leaf_mutation =
                library.import(Box::new(MmrCalculateNewPeaksFromLeafMutationMtIndices));

            let code = triton_asm!(
                lala_entrypoint:
                    push 1 call {memcpy}
                    call {calculate_new_peaks_from_leaf_mutation}

                    return
            );

            let mut src = code;
            let mut imports = library.all_imports();

            // Sanity check on `all_external_dependencies`, checking that they are
            // *also* sorted alphabetically.
            let all_ext_deps = library.all_external_dependencies();
            let imports_repeated = all_ext_deps.concat();
            assert_eq!(imports, imports_repeated);

            src.append(&mut imports);

            Program::new(&src)
        }

        for _ in 0..100 {
            let program = smaller_program();
            let same_program = smaller_program();
            assert_eq!(program, same_program);
        }
    }

    #[test]
    fn static_memory_region_spans_two_to_the_thirty_two_words() {
        assert_eq!(
            bfe!(u32::MAX),
            STATIC_MEMORY_FIRST_ADDRESS - STATIC_MEMORY_LAST_ADDRESS
        );
    }

    #[test]
    fn kmalloc_test() {
        const MINUS_TWO: BFieldElement = BFieldElement::new(BFieldElement::MAX - 1);
        let mut lib = Library::new();

        let first_chunk = lib.kmalloc(1);
        assert_eq!(MINUS_TWO, first_chunk.write_address());
        assert_eq!(MINUS_TWO, first_chunk.read_address());
        assert_eq!(1, first_chunk.num_words());

        let second_chunk = lib.kmalloc(7);
        assert_eq!(-bfe!(9), second_chunk.write_address());
        assert_eq!(-bfe!(3), second_chunk.read_address());
        assert_eq!(7, second_chunk.num_words());

        let third_chunk = lib.kmalloc(1000);
        assert_eq!(-bfe!(1009), third_chunk.write_address());
        assert_eq!(-bfe!(10), third_chunk.read_address());
        assert_eq!(1000, third_chunk.num_words());
    }

    #[test]
    #[should_panic(expected = "must allocate a positive number of words")]
    fn kmalloc_of_zero_words_panics() {
        Library::new().kmalloc(0);
    }

    #[test]
    #[should_panic(expected = "u32::MAX words through `kmalloc`")]
    fn kmalloc_beyond_u32_max_words_panics() {
        Library::with_preallocated_memory(u32::MAX).kmalloc(1);
    }
}