1use std::collections::HashMap;
2
3use arbitrary::Arbitrary;
4use itertools::Itertools;
5use num_traits::ConstOne;
6use triton_vm::memory_layout::MemoryRegion;
7use triton_vm::prelude::*;
8
9use crate::prelude::*;
10
11const STATIC_MEMORY_FIRST_ADDRESS_AS_U64: u64 = BFieldElement::MAX - 1;
19pub const STATIC_MEMORY_FIRST_ADDRESS: BFieldElement =
20 BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64);
21pub const STATIC_MEMORY_LAST_ADDRESS: BFieldElement =
22 BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64 - u32::MAX as u64);
23
24#[derive(Clone, Debug)]
27pub struct Library {
28 seen_snippets: HashMap<String, Vec<LabelledInstruction>>,
30
31 num_allocated_words: u32,
33}
34
35#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
40pub struct StaticAllocation {
41 write_address: BFieldElement,
42 num_words: u32,
43}
44
45impl StaticAllocation {
46 pub fn read_address(&self) -> BFieldElement {
48 let offset = bfe!(self.num_words) - BFieldElement::ONE;
49 self.write_address() + offset
50 }
51
52 pub fn write_address(&self) -> BFieldElement {
54 self.write_address
55 }
56
57 pub fn num_words(&self) -> u32 {
59 self.num_words
60 }
61}
62
63impl Default for Library {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl Library {
70 pub fn kmalloc_memory_region() -> MemoryRegion {
71 MemoryRegion::new(STATIC_MEMORY_LAST_ADDRESS, 1usize << 32)
72 }
73
74 pub fn new() -> Self {
75 Self {
76 seen_snippets: HashMap::default(),
77 num_allocated_words: 0,
78 }
79 }
80
81 pub fn empty() -> Self {
83 Self::new()
84 }
85
86 #[cfg(test)]
87 pub fn with_preallocated_memory(words_statically_allocated: u32) -> Self {
88 Library {
89 num_allocated_words: words_statically_allocated,
90 ..Self::new()
91 }
92 }
93
94 pub fn import(&mut self, snippet: Box<dyn BasicSnippet>) -> String {
103 let dep_entrypoint = snippet.entrypoint();
104
105 let is_new_dependency = !self.seen_snippets.contains_key(&dep_entrypoint);
106 if is_new_dependency {
107 let dep_body = snippet.annotated_code(self);
108 self.seen_snippets.insert(dep_entrypoint.clone(), dep_body);
109 }
110
111 dep_entrypoint
112 }
113
114 pub fn explicit_import(&mut self, name: &str, body: &[LabelledInstruction]) -> String {
122 if !self.seen_snippets.contains_key(name) {
123 self.seen_snippets.insert(name.to_owned(), body.to_vec());
124 }
125
126 name.to_string()
127 }
128
129 pub fn all_external_dependencies(&self) -> Vec<Vec<LabelledInstruction>> {
132 self.seen_snippets
133 .iter()
134 .sorted_by_key(|(k, _)| *k)
135 .map(|(_, code)| code.clone())
136 .collect()
137 }
138
139 pub fn get_all_snippet_names(&self) -> Vec<String> {
142 let mut ret = self.seen_snippets.keys().cloned().collect_vec();
143 ret.sort_unstable();
144 ret
145 }
146
147 pub fn all_imports(&self) -> Vec<LabelledInstruction> {
149 self.all_external_dependencies().concat()
150 }
151
152 pub fn kmalloc(&mut self, num_words: u32) -> StaticAllocation {
160 assert!(num_words > 0, "must allocate a positive number of words");
161 let write_address =
162 STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.num_allocated_words) - bfe!(num_words - 1);
163 self.num_allocated_words = self
164 .num_allocated_words
165 .checked_add(num_words)
166 .expect("Cannot allocate more that u32::MAX words through `kmalloc`.");
167
168 StaticAllocation {
169 write_address,
170 num_words,
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use triton_vm::prelude::Program;
    use triton_vm::prelude::triton_asm;

    use super::*;
    use crate::mmr::calculate_new_peaks_from_leaf_mutation::MmrCalculateNewPeaksFromLeafMutationMtIndices;
    use crate::test_prelude::*;

    #[derive(Debug, Copy, Clone, BFieldCodec)]
    struct ZeroSizedType;

    /// Imports both [`DummyTestSnippetB`] and [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetA;

    /// Imports [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetB;

    /// Has no dependencies.
    #[derive(Debug)]
    struct DummyTestSnippetC;

    impl BasicSnippet for DummyTestSnippetA {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "dummy".to_string())]
        }

        fn entrypoint(&self) -> String {
            "tasmlib_a_dummy_test_value".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let b = library.import(Box::new(DummyTestSnippetB));
            let c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {self.entrypoint()}:
                    call {b}
                    call {c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetB {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            ["1"; 2]
                .map(|name| (DataType::Bfe, name.to_string()))
                .to_vec()
        }

        fn entrypoint(&self) -> String {
            "tasmlib_b_dummy_test_value".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {self.entrypoint()}:
                    call {c}
                    call {c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetC {
        fn inputs(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn outputs(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bfe, "1".to_string())]
        }

        fn entrypoint(&self) -> String {
            "tasmlib_c_dummy_test_value".to_string()
        }

        fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
            triton_asm!({self.entrypoint()}: push 1 return)
        }
    }

    impl Closure for DummyTestSnippetA {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            push_encodable(stack, &xfe![[1, 1, 1]]);
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetB {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            stack.push(bfe!(1));
            stack.push(bfe!(1));
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetC {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            stack.push(bfe!(1));
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    #[test]
    fn library_includes() {
        ShadowedClosure::new(DummyTestSnippetA).test();
        ShadowedClosure::new(DummyTestSnippetB).test();
        ShadowedClosure::new(DummyTestSnippetC).test();
    }

    #[test]
    fn get_all_snippet_names_test_a() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetA));
        assert_eq!(
            vec![
                "tasmlib_a_dummy_test_value",
                "tasmlib_b_dummy_test_value",
                "tasmlib_c_dummy_test_value",
            ],
            lib.get_all_snippet_names()
        );
    }

    #[test]
    fn get_all_snippet_names_test_b() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetB));
        assert_eq!(
            vec!["tasmlib_b_dummy_test_value", "tasmlib_c_dummy_test_value"],
            lib.get_all_snippet_names()
        );
    }

    #[test]
    fn all_imports_as_instruction_lists() {
        let mut lib = Library::new();
        let entrypoint = lib.import(Box::new(DummyTestSnippetA));
        assert_eq!("tasmlib_a_dummy_test_value", entrypoint);
        let imports_after_first_import = lib.all_imports();

        // Re-importing a snippet, or a dependency that was already pulled in transitively,
        // must not add any code.
        lib.import(Box::new(DummyTestSnippetA));
        lib.import(Box::new(DummyTestSnippetC));
        assert_eq!(imports_after_first_import, lib.all_imports());

        assert_eq!(3, lib.all_external_dependencies().len());
        assert_eq!(lib.all_external_dependencies().concat(), lib.all_imports());
    }

    #[test]
    fn program_is_deterministic() {
        fn smaller_program() -> Program {
            let mut library = Library::new();
            let memcpy = library.import(Box::new(MemCpy));
            let calculate_new_peaks_from_leaf_mutation =
                library.import(Box::new(MmrCalculateNewPeaksFromLeafMutationMtIndices));

            let code = triton_asm!(
                lala_entrypoint:
                    push 1 call {memcpy}
                    call {calculate_new_peaks_from_leaf_mutation}

                    return
            );

            let mut src = code;
            let mut imports = library.all_imports();

            let all_ext_deps = library.all_external_dependencies();
            let imports_repeated = all_ext_deps.concat();
            assert_eq!(imports, imports_repeated);

            src.append(&mut imports);

            Program::new(&src)
        }

        for _ in 0..100 {
            let program = smaller_program();
            let same_program = smaller_program();
            assert_eq!(program, same_program);
        }
    }

    #[test]
    fn kmalloc_test() {
        const MINUS_TWO: BFieldElement = BFieldElement::new(BFieldElement::MAX - 1);
        let mut lib = Library::new();

        // Chunks are handed out downward and back to back: [-2], [-9, -3], [-1009, -10].
        let first_chunk = lib.kmalloc(1);
        assert_eq!(MINUS_TWO, first_chunk.write_address());
        assert_eq!(MINUS_TWO, first_chunk.read_address());
        assert_eq!(1, first_chunk.num_words());

        let second_chunk = lib.kmalloc(7);
        assert_eq!(-bfe!(9), second_chunk.write_address());
        assert_eq!(-bfe!(3), second_chunk.read_address());
        assert_eq!(7, second_chunk.num_words());

        let third_chunk = lib.kmalloc(1000);
        assert_eq!(-bfe!(1009), third_chunk.write_address());
        assert_eq!(-bfe!(10), third_chunk.read_address());
        assert_eq!(1000, third_chunk.num_words());
    }

    #[test]
    #[should_panic(expected = "must allocate a positive number of words")]
    fn kmalloc_rejects_zero_words() {
        Library::new().kmalloc(0);
    }

    #[test]
    #[should_panic(expected = "u32::MAX")]
    fn kmalloc_panics_when_static_memory_is_exhausted() {
        let mut lib = Library::with_preallocated_memory(u32::MAX);
        lib.kmalloc(1);
    }
}
397}