// tasm_lib/arithmetic/u64/div_mod.rs

use triton_vm::prelude::*;

use crate::arithmetic::u32::safe_add::SafeAdd;
use crate::arithmetic::u32::safe_sub::SafeSub;
use crate::arithmetic::u64::and::And;
use crate::arithmetic::u64::leading_zeros::LeadingZeros;
use crate::arithmetic::u64::lt::Lt;
use crate::arithmetic::u64::or::Or;
use crate::arithmetic::u64::shift_left::ShiftLeft;
use crate::arithmetic::u64::shift_right::ShiftRight;
use crate::arithmetic::u64::sub::Sub;
use crate::prelude::*;

/// Snippet that divides one `u64` by another, yielding both the quotient and
/// the remainder.
///
/// The type carries no state; all behavior lives in its `BasicSnippet`
/// implementation.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct DivMod;

17impl DivMod {
18    pub const DIVISION_BY_ZERO_ERROR_ID: i128 = 420;
19}
20
21impl BasicSnippet for DivMod {
22    fn inputs(&self) -> Vec<(DataType, String)> {
23        ["numerator", "denominator"]
24            .map(|name| (DataType::U64, name.to_string()))
25            .to_vec()
26    }
27
28    fn outputs(&self) -> Vec<(DataType, String)> {
29        ["quotient", "remainder"]
30            .map(|name| (DataType::U64, name.to_string()))
31            .to_vec()
32    }
33
34    fn entrypoint(&self) -> String {
35        "tasmlib_arithmetic_u64_div_mod".to_string()
36    }
37
38    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
39        let shift_right_u64 = library.import(Box::new(ShiftRight));
40        let shift_left_u64 = library.import(Box::new(ShiftLeft));
41        let and_u64 = library.import(Box::new(And));
42        let lt_u64 = library.import(Box::new(Lt));
43        let or_u64 = library.import(Box::new(Or));
44        let sub_u64 = library.import(Box::new(Sub));
45        let sub_u32 = library.import(Box::new(SafeSub));
46        let leading_zeros_u64 = library.import(Box::new(LeadingZeros));
47        let add_u32 = library.import(Box::new(SafeAdd));
48        let spilled_denominator_alloc = library.kmalloc(2);
49
50        // The below code has been compiled from a Rust implementation of an LLVM function
51        // called `divmoddi4` that can do u64 divmod with only access to u32 bit divmod and
52        // some u64 arithmetic instructions or functions. The compiler used for this was the
53        // `tasm-lang` compiler: https://github.com/TritonVM/tasm-lang
54        // You could probably get a smaller cycle count if you hand-compiled the function.
55        //
56        // If you do attempt this, check out the following resources:
57        // https://github.com/llvm/llvm-project/compiler-rt/lib/builtins/udivmoddi4.c
58        // which is based on “The PowerPC Compiler Writer’s Guide”
59        // (https://cr.yp.to/2005-590/powerpc-cwg.pdf) section 3.2.3.7:
60        // “32-Bit Implementation of a 64-Bit Unsigned Divide”
61        triton_asm!(
62            // BEFORE: _ numerator_hi numerator_lo denominator_hi denominator_lo
63            // AFTER:  _ quotient_hi quotient_lo remainder_hi remainder_lo
64            {self.entrypoint()}:
65                dup 1
66                dup 1
67                push {spilled_denominator_alloc.write_address()}
68                write_mem 2
69                pop 1
70                dup 3
71                dup 3
72                push 32
73                call {shift_right_u64}
74                swap 1
75                pop 1
76                dup 4
77                dup 4
78                push 00000000004294967295
79                push 0
80                swap 1
81                call {and_u64}
82                swap 1
83                pop 1
84                push {spilled_denominator_alloc.read_address()}
85                read_mem {spilled_denominator_alloc.num_words()}
86                pop 1
87                push 32
88                call {shift_right_u64}
89                swap 1
90                pop 1
91                push {spilled_denominator_alloc.read_address()}
92                read_mem {spilled_denominator_alloc.num_words()}
93                pop 1
94                push 00000000004294967295
95                push 0
96                swap 1
97                call {and_u64}
98                swap 1
99                pop 1
100                push 0
101                push 0
102                push 0
103                push 0
104                dup 11
105                dup 11
106                push {spilled_denominator_alloc.read_address()}
107                read_mem {spilled_denominator_alloc.num_words()}
108                pop 1
109                dup 3
110                dup 3
111                call {lt_u64}
112                push 1
113                swap 1
114                skiz
115                call _binop_Gt_bool_bool_26_then
116                skiz
117                call _binop_Gt_bool_bool_26_else
118                pop 2
119                swap 8
120                pop 1
121                swap 8
122                pop 1
123                swap 8
124                pop 1
125                swap 8
126                pop 5
127                return
128                _binop_Eq_bool_bool_53_then:
129                pop 1
130                dup 8
131                dup 7
132                swap 1
133                div_mod
134                pop 1
135                push 0
136                swap 1
137                dup 10
138                dup 9
139                swap 1
140                div_mod
141                swap 1
142                pop 1
143                push 0
144                swap 1
145                swap 6
146                pop 1
147                swap 6
148                pop 1
149                swap 6
150                pop 1
151                swap 6
152                pop 1
153                push 0
154                return
155                _binop_Eq_bool_bool_53_else:
156                return
157                _binop_Eq_bool_bool_47_then:
158                pop 1
159                dup 1
160                dup 1
161                push 0
162                push 0
163                swap 6
164                pop 1
165                swap 6
166                pop 1
167                swap 6
168                pop 1
169                swap 6
170                pop 1
171                push 0
172                return
173                _binop_Eq_bool_bool_47_else:
174                dup 9
175                push 0
176                eq
177                push 1
178                swap 1
179                skiz
180                call _binop_Eq_bool_bool_53_then
181                skiz
182                call _binop_Eq_bool_bool_53_else
183                return
184                _lit_u64_u64_99_then:
185                pop 1
186                push 0
187                push 0
188                push 0
189                return
190                _lit_u64_u64_99_else:
191                push 00000000004294967295
192                push 00000000004294967295
193                return
194                _binop_Gt_bool_bool_81_while_loop:
195                dup 4
196                push 0
197                lt
198                push 0
199                eq
200                skiz
201                return
202                dup 3
203                dup 3
204                push 1
205                call {shift_left_u64}
206                dup 8
207                dup 8
208                push 63
209                call {shift_right_u64}
210                call {or_u64}
211                swap 4
212                pop 1
213                swap 4
214                pop 1
215                dup 6
216                dup 6
217                push 1
218                call {shift_left_u64}
219                dup 3
220                dup 3
221                push 0
222                push 1
223                call {and_u64}
224                call {or_u64}
225                swap 7
226                pop 1
227                swap 7
228                pop 1
229                push {spilled_denominator_alloc.read_address()}
230                read_mem {spilled_denominator_alloc.num_words()}
231                pop 1
232                dup 5
233                dup 5
234                call {lt_u64}
235                push 1
236                swap 1
237                skiz
238                call _lit_u64_u64_99_then
239                skiz
240                call _lit_u64_u64_99_else
241                swap 2
242                pop 1
243                swap 2
244                pop 1
245                dup 3
246                dup 3
247                push {spilled_denominator_alloc.read_address()}
248                read_mem {spilled_denominator_alloc.num_words()}
249                pop 1
250                dup 5
251                dup 5
252                call {and_u64}
253                swap 3
254                swap 1
255                swap 3
256                swap 2
257                call {sub_u64}
258                swap 4
259                pop 1
260                swap 4
261                pop 1
262                dup 4
263                push 1
264                swap 1
265                call {sub_u32}
266                swap 5
267                pop 1
268                recurse
269                _binop_Or_bool_bool_44_then:
270                pop 1
271                push {spilled_denominator_alloc.read_address()}
272                read_mem {spilled_denominator_alloc.num_words()}
273                pop 1
274                push 0
275                push 1
276                swap 3
277                eq
278                swap 2
279                eq
280                mul
281                push 1
282                swap 1
283                skiz
284                call _binop_Eq_bool_bool_47_then
285                skiz
286                call _binop_Eq_bool_bool_47_else
287                push 0
288                return
289                _binop_Or_bool_bool_44_else:
290                push 0
291                push 0
292                push {spilled_denominator_alloc.read_address()}
293                read_mem {spilled_denominator_alloc.num_words()}
294                pop 1
295                swap 3
296                eq
297                swap 2
298                eq
299                mul
300                push 0
301                eq
302                assert error_id {Self::DIVISION_BY_ZERO_ERROR_ID}
303                push {spilled_denominator_alloc.read_address()}
304                read_mem {spilled_denominator_alloc.num_words()}
305                pop 1
306                call {leading_zeros_u64}
307                dup 2
308                dup 2
309                call {leading_zeros_u64}
310                swap 1
311                call {sub_u32}
312                push 1
313                call {add_u32}
314                dup 2
315                dup 2
316                dup 2
317                call {shift_right_u64}
318                dup 4
319                dup 4
320                push 64
321                dup 5
322                swap 1
323                call {sub_u32}
324                call {shift_left_u64}
325                swap 5
326                pop 1
327                swap 5
328                pop 1
329                push 0
330                push 0
331                call _binop_Gt_bool_bool_81_while_loop
332                dup 6
333                dup 6
334                push 1
335                call {shift_left_u64}
336                dup 3
337                dup 3
338                push 0
339                push 1
340                call {and_u64}
341                call {or_u64}
342                dup 5
343                dup 5
344                swap 11
345                pop 1
346                swap 11
347                pop 1
348                swap 11
349                pop 1
350                swap 11
351                pop 5
352                pop 1
353                return
354                _binop_Gt_bool_bool_26_then:
355                pop 1
356                push 0
357                push 0
358                dup 3
359                dup 3
360                swap 6
361                pop 1
362                swap 6
363                pop 1
364                swap 6
365                pop 1
366                swap 6
367                pop 1
368                push 0
369                return
370                _binop_Gt_bool_bool_26_else:
371                dup 7
372                push 0
373                eq
374                push {spilled_denominator_alloc.read_address()}
375                read_mem {spilled_denominator_alloc.num_words()}
376                pop 1
377                push 0
378                push 1
379                swap 3
380                eq
381                swap 2
382                eq
383                mul
384                add
385                push 2
386                eq
387                dup 8
388                push 0
389                eq
390                dup 11
391                push 0
392                eq
393                add
394                push 2
395                eq
396                add
397                push 0
398                eq
399                push 0
400                eq
401                push 1
402                swap 1
403                skiz
404                call _binop_Or_bool_bool_44_then
405                skiz
406                call _binop_Or_bool_bool_44_else
407                return
408        )
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::test_prelude::*;

    impl DivMod {
        /// Build an initial state whose stack holds `numerator` and, on top of it,
        /// `denominator`; everything else is left at its default.
        fn set_up_initial_state(&self, numerator: u64, denominator: u64) -> FunctionInitialState {
            let mut stack = self.init_stack_for_isolated_run();
            push_encodable(&mut stack, &numerator);
            push_encodable(&mut stack, &denominator);

            FunctionInitialState {
                stack,
                ..Default::default()
            }
        }
    }

    impl Function for DivMod {
        /// Reference implementation using Rust's native `u64` division.
        ///
        /// Panics on a zero denominator, as Rust's integer division does.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            let denominator = pop_encodable::<u64>(stack);
            let numerator = pop_encodable::<u64>(stack);
            let quotient = numerator / denominator;
            let remainder = numerator % denominator;
            push_encodable(stack, &quotient);
            push_encodable(stack, &remainder);

            // Accommodate spilling: the assembly writes the denominator to static
            // memory, so the shadow has to leave the same memory behind. This could
            // probably be avoided if the code was compiled by hand instead.
            encode_to_memory(memory, STATIC_MEMORY_FIRST_ADDRESS - bfe!(1), &denominator);
        }

        /// Fixed operands for the benchmark cases, seeded random operands otherwise.
        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let (numerator, denominator) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (u32::MAX.into(), 1 << 15),
                Some(BenchmarkCase::WorstCase) => (u64::MAX, (1 << 32) + 45454545),
                None => StdRng::from_seed(seed).random(),
            };

            self.set_up_initial_state(numerator, denominator)
        }

        /// All pairs drawn from a set of operands of different bit lengths, plus a
        /// list of hand-picked operand pairs.
        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            const NOISE: u64 = 0x6d26_150f_4669_d677;

            // For every third bit position `i`: bit `i` set, with the bits below it
            // taken from `NOISE`. The top bit guarantees a non-zero value.
            let u64s_of_different_magnitudes = (0..u64::BITS)
                .step_by(3) // test performance is atrocious otherwise
                .map(|i| 1 << i)
                .map(|x| x | (x - 1) & NOISE);

            let mut states = u64s_of_different_magnitudes
                .clone()
                .cartesian_product(u64s_of_different_magnitudes.clone())
                .map(|(n, d)| self.set_up_initial_state(n, d))
                .collect_vec();

            let additional_inputs = [
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 100),
                (0, u32::MAX as u64),
                (0, 0xFFFF_FFFF_0000_0000),
                (0, 11428751156810088448),
                (1000, 100),
                // found in bug reports online
                (6098312677908545536, 6098805452391317504),
                (5373808693584330752, 11428751156810088448),
                (8268416007396130816, 6204028719464448000),
                // suggested by an LLM
                (u64::MAX, 1),
                (u64::MAX, 2),
                (u64::MAX, u64::MAX),
                (0x0000_0001_FFFF_FFFF, 0xFFFF_FFFF_0000_0000),
                (0xFFFF_FFFF_0000_0000, 0x0000_0000_FFFF_FFFF),
                (0xABCD_EF12_3456_789A, 0x1234_5678_9ABC_DEF0),
                // edge cases around powers of two
                (u64::MAX, (1 << 31) + 1),
                (u64::MAX, (1 << 31) + 454545454),
                (u64::MAX, (1 << 32) - 1),
                (u64::MAX, 1 << 32),
                (u64::MAX, (1 << 32) + 1),
                (u64::MAX, (1 << 32) + 2),
                (u64::MAX, (1 << 32) + 3),
                (u64::MAX, (1 << 32) + 454545454),
                (u64::MAX, (1 << 33) - 1),
                (u64::MAX, 1 << 33),
                (u64::MAX, (1 << 33) + 1),
                (u64::MAX, (1 << 33) + 454545454),
                (u64::MAX, (1 << 34) + 454545454),
                (u64::MAX, (1 << 35) + 454545454),
                (u64::MAX - 1, (1 << 32) - 2),
                (u64::MAX - 1, (1 << 32) - 1),
                (u64::MAX - 1, 1 << 32),
                (u64::MAX - 1, (1 << 32) + 1),
                (u64::MAX - 1, (1 << 32) + 2),
                (u64::MAX - 1, (1 << 32) + 3),
                (u64::MAX - 1, (1 << 33) - 1),
                (u64::MAX - 1, 1 << 33),
                (u64::MAX - 1, (1 << 33) + 1),
            ];

            states.extend(additional_inputs.map(|(n, d)| self.set_up_initial_state(n, d)));
            states
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(DivMod).test();
    }

    /// Division by zero must trip the snippet's own assertion.
    ///
    /// The numerator is restricted to values that do not fit into a `u32`: only
    /// then does the assembly take the path containing the assertion with
    /// [`DivMod::DIVISION_BY_ZERO_ERROR_ID`]. If the high limbs of both operands
    /// are zero, the assembly executes the native `div_mod` instruction instead,
    /// whose failure does not carry that error ID.
    #[proptest]
    fn fail_vm_execution_on_divide_by_zero_u64_numerator(
        #[strategy((1_u64 << 32)..)] numerator: u64,
    ) {
        test_assertion_failure(
            &ShadowedFunction::new(DivMod),
            DivMod.set_up_initial_state(numerator, 0).into(),
            &[DivMod::DIVISION_BY_ZERO_ERROR_ID],
        );
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark the snippet through its `ShadowedFunction` wrapper.
    #[test]
    fn benchmark() {
        ShadowedFunction::new(DivMod).bench();
    }
}