1use triton_vm::prelude::*;
2
3use crate::arithmetic::u32::safe_add::SafeAdd;
4use crate::arithmetic::u32::safe_sub::SafeSub;
5use crate::arithmetic::u64::and::And;
6use crate::arithmetic::u64::leading_zeros::LeadingZeros;
7use crate::arithmetic::u64::lt::Lt;
8use crate::arithmetic::u64::or::Or;
9use crate::arithmetic::u64::shift_left::ShiftLeft;
10use crate::arithmetic::u64::shift_right::ShiftRight;
11use crate::arithmetic::u64::sub::Sub;
12use crate::prelude::*;
13
/// Snippet marker type for 64-bit unsigned division with remainder.
///
/// The type carries no data; all behavior lives in its trait implementations.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct DivMod;
16
impl DivMod {
    /// Error ID attached to the `assert` in the generated code that rejects a
    /// zero denominator.
    pub const DIVISION_BY_ZERO_ERROR_ID: i128 = 420;
}
20
21impl BasicSnippet for DivMod {
22 fn inputs(&self) -> Vec<(DataType, String)> {
23 ["numerator", "denominator"]
24 .map(|name| (DataType::U64, name.to_string()))
25 .to_vec()
26 }
27
28 fn outputs(&self) -> Vec<(DataType, String)> {
29 ["quotient", "remainder"]
30 .map(|name| (DataType::U64, name.to_string()))
31 .to_vec()
32 }
33
34 fn entrypoint(&self) -> String {
35 "tasmlib_arithmetic_u64_div_mod".to_string()
36 }
37
38 fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
39 let shift_right_u64 = library.import(Box::new(ShiftRight));
40 let shift_left_u64 = library.import(Box::new(ShiftLeft));
41 let and_u64 = library.import(Box::new(And));
42 let lt_u64 = library.import(Box::new(Lt));
43 let or_u64 = library.import(Box::new(Or));
44 let sub_u64 = library.import(Box::new(Sub));
45 let sub_u32 = library.import(Box::new(SafeSub));
46 let leading_zeros_u64 = library.import(Box::new(LeadingZeros));
47 let add_u32 = library.import(Box::new(SafeAdd));
48 let spilled_denominator_alloc = library.kmalloc(2);
49
50 triton_asm!(
62 {self.entrypoint()}:
65 dup 1
66 dup 1
67 push {spilled_denominator_alloc.write_address()}
68 write_mem 2
69 pop 1
70 dup 3
71 dup 3
72 push 32
73 call {shift_right_u64}
74 swap 1
75 pop 1
76 dup 4
77 dup 4
78 push 00000000004294967295
79 push 0
80 swap 1
81 call {and_u64}
82 swap 1
83 pop 1
84 push {spilled_denominator_alloc.read_address()}
85 read_mem {spilled_denominator_alloc.num_words()}
86 pop 1
87 push 32
88 call {shift_right_u64}
89 swap 1
90 pop 1
91 push {spilled_denominator_alloc.read_address()}
92 read_mem {spilled_denominator_alloc.num_words()}
93 pop 1
94 push 00000000004294967295
95 push 0
96 swap 1
97 call {and_u64}
98 swap 1
99 pop 1
100 push 0
101 push 0
102 push 0
103 push 0
104 dup 11
105 dup 11
106 push {spilled_denominator_alloc.read_address()}
107 read_mem {spilled_denominator_alloc.num_words()}
108 pop 1
109 dup 3
110 dup 3
111 call {lt_u64}
112 push 1
113 swap 1
114 skiz
115 call _binop_Gt_bool_bool_26_then
116 skiz
117 call _binop_Gt_bool_bool_26_else
118 pop 2
119 swap 8
120 pop 1
121 swap 8
122 pop 1
123 swap 8
124 pop 1
125 swap 8
126 pop 5
127 return
128 _binop_Eq_bool_bool_53_then:
129 pop 1
130 dup 8
131 dup 7
132 swap 1
133 div_mod
134 pop 1
135 push 0
136 swap 1
137 dup 10
138 dup 9
139 swap 1
140 div_mod
141 swap 1
142 pop 1
143 push 0
144 swap 1
145 swap 6
146 pop 1
147 swap 6
148 pop 1
149 swap 6
150 pop 1
151 swap 6
152 pop 1
153 push 0
154 return
155 _binop_Eq_bool_bool_53_else:
156 return
157 _binop_Eq_bool_bool_47_then:
158 pop 1
159 dup 1
160 dup 1
161 push 0
162 push 0
163 swap 6
164 pop 1
165 swap 6
166 pop 1
167 swap 6
168 pop 1
169 swap 6
170 pop 1
171 push 0
172 return
173 _binop_Eq_bool_bool_47_else:
174 dup 9
175 push 0
176 eq
177 push 1
178 swap 1
179 skiz
180 call _binop_Eq_bool_bool_53_then
181 skiz
182 call _binop_Eq_bool_bool_53_else
183 return
184 _lit_u64_u64_99_then:
185 pop 1
186 push 0
187 push 0
188 push 0
189 return
190 _lit_u64_u64_99_else:
191 push 00000000004294967295
192 push 00000000004294967295
193 return
194 _binop_Gt_bool_bool_81_while_loop:
195 dup 4
196 push 0
197 lt
198 push 0
199 eq
200 skiz
201 return
202 dup 3
203 dup 3
204 push 1
205 call {shift_left_u64}
206 dup 8
207 dup 8
208 push 63
209 call {shift_right_u64}
210 call {or_u64}
211 swap 4
212 pop 1
213 swap 4
214 pop 1
215 dup 6
216 dup 6
217 push 1
218 call {shift_left_u64}
219 dup 3
220 dup 3
221 push 0
222 push 1
223 call {and_u64}
224 call {or_u64}
225 swap 7
226 pop 1
227 swap 7
228 pop 1
229 push {spilled_denominator_alloc.read_address()}
230 read_mem {spilled_denominator_alloc.num_words()}
231 pop 1
232 dup 5
233 dup 5
234 call {lt_u64}
235 push 1
236 swap 1
237 skiz
238 call _lit_u64_u64_99_then
239 skiz
240 call _lit_u64_u64_99_else
241 swap 2
242 pop 1
243 swap 2
244 pop 1
245 dup 3
246 dup 3
247 push {spilled_denominator_alloc.read_address()}
248 read_mem {spilled_denominator_alloc.num_words()}
249 pop 1
250 dup 5
251 dup 5
252 call {and_u64}
253 swap 3
254 swap 1
255 swap 3
256 swap 2
257 call {sub_u64}
258 swap 4
259 pop 1
260 swap 4
261 pop 1
262 dup 4
263 push 1
264 swap 1
265 call {sub_u32}
266 swap 5
267 pop 1
268 recurse
269 _binop_Or_bool_bool_44_then:
270 pop 1
271 push {spilled_denominator_alloc.read_address()}
272 read_mem {spilled_denominator_alloc.num_words()}
273 pop 1
274 push 0
275 push 1
276 swap 3
277 eq
278 swap 2
279 eq
280 mul
281 push 1
282 swap 1
283 skiz
284 call _binop_Eq_bool_bool_47_then
285 skiz
286 call _binop_Eq_bool_bool_47_else
287 push 0
288 return
289 _binop_Or_bool_bool_44_else:
290 push 0
291 push 0
292 push {spilled_denominator_alloc.read_address()}
293 read_mem {spilled_denominator_alloc.num_words()}
294 pop 1
295 swap 3
296 eq
297 swap 2
298 eq
299 mul
300 push 0
301 eq
302 assert error_id {Self::DIVISION_BY_ZERO_ERROR_ID}
303 push {spilled_denominator_alloc.read_address()}
304 read_mem {spilled_denominator_alloc.num_words()}
305 pop 1
306 call {leading_zeros_u64}
307 dup 2
308 dup 2
309 call {leading_zeros_u64}
310 swap 1
311 call {sub_u32}
312 push 1
313 call {add_u32}
314 dup 2
315 dup 2
316 dup 2
317 call {shift_right_u64}
318 dup 4
319 dup 4
320 push 64
321 dup 5
322 swap 1
323 call {sub_u32}
324 call {shift_left_u64}
325 swap 5
326 pop 1
327 swap 5
328 pop 1
329 push 0
330 push 0
331 call _binop_Gt_bool_bool_81_while_loop
332 dup 6
333 dup 6
334 push 1
335 call {shift_left_u64}
336 dup 3
337 dup 3
338 push 0
339 push 1
340 call {and_u64}
341 call {or_u64}
342 dup 5
343 dup 5
344 swap 11
345 pop 1
346 swap 11
347 pop 1
348 swap 11
349 pop 1
350 swap 11
351 pop 5
352 pop 1
353 return
354 _binop_Gt_bool_bool_26_then:
355 pop 1
356 push 0
357 push 0
358 dup 3
359 dup 3
360 swap 6
361 pop 1
362 swap 6
363 pop 1
364 swap 6
365 pop 1
366 swap 6
367 pop 1
368 push 0
369 return
370 _binop_Gt_bool_bool_26_else:
371 dup 7
372 push 0
373 eq
374 push {spilled_denominator_alloc.read_address()}
375 read_mem {spilled_denominator_alloc.num_words()}
376 pop 1
377 push 0
378 push 1
379 swap 3
380 eq
381 swap 2
382 eq
383 mul
384 add
385 push 2
386 eq
387 dup 8
388 push 0
389 eq
390 dup 11
391 push 0
392 eq
393 add
394 push 2
395 eq
396 add
397 push 0
398 eq
399 push 0
400 eq
401 push 1
402 swap 1
403 skiz
404 call _binop_Or_bool_bool_44_then
405 skiz
406 call _binop_Or_bool_bool_44_else
407 return
408 )
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::test_prelude::*;

    impl DivMod {
        /// Builds an initial state whose stack holds `numerator` and, above
        /// it, `denominator`; everything else is left at its default.
        fn set_up_initial_state(&self, numerator: u64, denominator: u64) -> FunctionInitialState {
            let mut stack = self.init_stack_for_isolated_run();
            push_encodable(&mut stack, &numerator);
            push_encodable(&mut stack, &denominator);

            FunctionInitialState {
                stack,
                ..Default::default()
            }
        }
    }

    impl Function for DivMod {
        /// Reference implementation using Rust's native `/` and `%`.
        ///
        /// Panics on a zero denominator, like the native operators do.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let denominator = pop_encodable::<u64>(stack);
            let numerator = pop_encodable::<u64>(stack);
            let quotient = numerator / denominator;
            let remainder = numerator % denominator;
            push_encodable(stack, &quotient);
            push_encodable(stack, &remainder);

            // The snippet spills the denominator into static memory; mirror
            // that write here. Presumably this address is the write address
            // of the snippet's two-word allocation — confirm against
            // `Library::kmalloc`.
            encode_to_memory(memory, STATIC_MEMORY_FIRST_ADDRESS - bfe!(1), &denominator);
        }

        /// Fixed operands for the benchmark cases, seeded random operands
        /// otherwise.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let (numerator, denominator) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (u32::MAX.into(), 1 << 15),
                Some(BenchmarkCase::WorstCase) => (u64::MAX, (1 << 32) + 45454545),
                None => StdRng::from_seed(seed).random(),
            };

            self.set_up_initial_state(numerator, denominator)
        }

        /// All pairs of operands of varying magnitudes, plus hand-picked pairs
        /// around the 32-bit limb boundary and the extremes of the `u64` range.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            const NOISE: u64 = 0x6d26_150f_4669_d677;

            // Every third power of two, with the bits below the leading bit
            // taken from `NOISE`. Note that `&` binds tighter than `|`.
            let u64s_of_different_magnitudes = (0..u64::BITS)
                .step_by(3)
                .map(|i| 1 << i)
                .map(|x| x | (x - 1) & NOISE);

            let mut states = u64s_of_different_magnitudes
                .clone()
                .cartesian_product(u64s_of_different_magnitudes.clone())
                .map(|(n, d)| self.set_up_initial_state(n, d))
                .collect_vec();

            let additional_inputs = [
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 100),
                (0, u32::MAX as u64),
                (0, 0xFFFF_FFFF_0000_0000),
                (0, 11428751156810088448),
                (1000, 100),
                (6098312677908545536, 6098805452391317504),
                (5373808693584330752, 11428751156810088448),
                (8268416007396130816, 6204028719464448000),
                (u64::MAX, 1),
                (u64::MAX, 2),
                (u64::MAX, u64::MAX),
                (0x0000_0001_FFFF_FFFF, 0xFFFF_FFFF_0000_0000),
                (0xFFFF_FFFF_0000_0000, 0x0000_0000_FFFF_FFFF),
                (0xABCD_EF12_3456_789A, 0x1234_5678_9ABC_DEF0),
                (u64::MAX, (1 << 31) + 1),
                (u64::MAX, (1 << 31) + 454545454),
                (u64::MAX, (1 << 32) - 1),
                (u64::MAX, 1 << 32),
                (u64::MAX, (1 << 32) + 1),
                (u64::MAX, (1 << 32) + 2),
                (u64::MAX, (1 << 32) + 3),
                (u64::MAX, (1 << 32) + 454545454),
                (u64::MAX, (1 << 33) - 1),
                (u64::MAX, 1 << 33),
                (u64::MAX, (1 << 33) + 1),
                (u64::MAX, (1 << 33) + 454545454),
                (u64::MAX, (1 << 34) + 454545454),
                (u64::MAX, (1 << 35) + 454545454),
                (u64::MAX - 1, (1 << 32) - 2),
                (u64::MAX - 1, (1 << 32) - 1),
                (u64::MAX - 1, 1 << 32),
                (u64::MAX - 1, (1 << 32) + 1),
                (u64::MAX - 1, (1 << 32) + 2),
                (u64::MAX - 1, (1 << 32) + 3),
                (u64::MAX - 1, (1 << 33) - 1),
                (u64::MAX - 1, 1 << 33),
                (u64::MAX - 1, (1 << 33) + 1),
            ];

            states.extend(additional_inputs.map(|(n, d)| self.set_up_initial_state(n, d)));
            states
        }
    }

    /// The generated code agrees with the Rust shadow.
    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(DivMod).test();
    }

    /// A zero denominator makes the VM fail with the snippet's error ID.
    ///
    /// NOTE(review): despite the name, the numerator is drawn from the whole
    /// `u64` range. The snippet's assertion sits on the path taken when the
    /// numerator's high limb is non-zero, so a sampled numerator that fits
    /// into 32 bits would presumably fail with a different error — confirm.
    #[proptest]
    fn fail_vm_execution_on_divide_by_zero_u32_numerator(numerator: u64) {
        test_assertion_failure(
            &ShadowedFunction::new(DivMod),
            DivMod.set_up_initial_state(numerator, 0).into(),
            &[DivMod::DIVISION_BY_ZERO_ERROR_ID],
        );
    }
}
542
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark of the shadowed `DivMod` snippet.
    #[test]
    fn benchmark() {
        let shadowed_div_mod = ShadowedFunction::new(DivMod);
        shadowed_div_mod.bench();
    }
}