Skip to main content

tasm_lib/arithmetic/u64/
div_mod.rs

1use triton_vm::prelude::*;
2
3use crate::arithmetic::u32::safe_add::SafeAdd;
4use crate::arithmetic::u32::safe_sub::SafeSub;
5use crate::arithmetic::u64::and::And;
6use crate::arithmetic::u64::leading_zeros::LeadingZeros;
7use crate::arithmetic::u64::lt::Lt;
8use crate::arithmetic::u64::or::Or;
9use crate::arithmetic::u64::shift_left::ShiftLeft;
10use crate::arithmetic::u64::shift_right::ShiftRight;
11use crate::arithmetic::u64::sub::Sub;
12use crate::prelude::*;
13
/// Snippet computing both the quotient and the remainder of a `u64` division.
///
/// Stack effect:
/// `_ numerator_hi numerator_lo denominator_hi denominator_lo`
/// becomes
/// `_ quotient_hi quotient_lo remainder_hi remainder_lo`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DivMod;

impl DivMod {
    /// Error id passed to the snippet's `assert` instruction; the VM fails with
    /// this id when the denominator is zero.
    pub const DIVISION_BY_ZERO_ERROR_ID: i128 = 420;
}
20
21impl BasicSnippet for DivMod {
22    fn parameters(&self) -> Vec<(DataType, String)> {
23        ["numerator", "denominator"]
24            .map(|name| (DataType::U64, name.to_string()))
25            .to_vec()
26    }
27
28    fn return_values(&self) -> Vec<(DataType, String)> {
29        ["quotient", "remainder"]
30            .map(|name| (DataType::U64, name.to_string()))
31            .to_vec()
32    }
33
34    fn entrypoint(&self) -> String {
35        "tasmlib_arithmetic_u64_div_mod".to_string()
36    }
37
38    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
39        let shift_right_u64 = library.import(Box::new(ShiftRight));
40        let shift_left_u64 = library.import(Box::new(ShiftLeft));
41        let and_u64 = library.import(Box::new(And));
42        let lt_u64 = library.import(Box::new(Lt));
43        let or_u64 = library.import(Box::new(Or));
44        let sub_u64 = library.import(Box::new(Sub));
45        let sub_u32 = library.import(Box::new(SafeSub));
46        let leading_zeros_u64 = library.import(Box::new(LeadingZeros));
47        let add_u32 = library.import(Box::new(SafeAdd));
48        let spilled_denominator_alloc = library.kmalloc(2);
49
50        // The below code has been compiled from a Rust implementation of an LLVM function
51        // called `divmoddi4` that can do u64 divmod with only access to u32 bit divmod and
52        // some u64 arithmetic instructions or functions. The compiler used for this was the
53        // `tasm-lang` compiler: https://github.com/TritonVM/tasm-lang
54        // You could probably get a smaller cycle count if you hand-compiled the function.
55        //
56        // If you do attempt this, check out the following resources:
57        // https://github.com/llvm/llvm-project/compiler-rt/lib/builtins/udivmoddi4.c
58        // which is based on “The PowerPC Compiler Writer’s Guide”
59        // (https://cr.yp.to/2005-590/powerpc-cwg.pdf) section 3.2.3.7:
60        // “32-Bit Implementation of a 64-Bit Unsigned Divide”
61        triton_asm!(
62            // BEFORE: _ numerator_hi numerator_lo denominator_hi denominator_lo
63            // AFTER:  _ quotient_hi quotient_lo remainder_hi remainder_lo
64            {self.entrypoint()}:
65                dup 1
66                dup 1
67                push {spilled_denominator_alloc.write_address()}
68                write_mem 2
69                pop 1
70                dup 3
71                dup 3
72                push 32
73                call {shift_right_u64}
74                swap 1
75                pop 1
76                dup 4
77                dup 4
78                push 00000000004294967295
79                push 0
80                swap 1
81                call {and_u64}
82                swap 1
83                pop 1
84                push {spilled_denominator_alloc.read_address()}
85                read_mem {spilled_denominator_alloc.num_words()}
86                pop 1
87                push 32
88                call {shift_right_u64}
89                swap 1
90                pop 1
91                push {spilled_denominator_alloc.read_address()}
92                read_mem {spilled_denominator_alloc.num_words()}
93                pop 1
94                push 00000000004294967295
95                push 0
96                swap 1
97                call {and_u64}
98                swap 1
99                pop 1
100                push 0
101                push 0
102                push 0
103                push 0
104                dup 11
105                dup 11
106                push {spilled_denominator_alloc.read_address()}
107                read_mem {spilled_denominator_alloc.num_words()}
108                pop 1
109                dup 3
110                dup 3
111                call {lt_u64}
112                push 1
113                swap 1
114                skiz
115                call _binop_Gt_bool_bool_26_then
116                skiz
117                call _binop_Gt_bool_bool_26_else
118                pop 2
119                swap 8
120                pop 1
121                swap 8
122                pop 1
123                swap 8
124                pop 1
125                swap 8
126                pop 5
127                return
128                _binop_Eq_bool_bool_53_then:
129                pop 1
130                dup 8
131                dup 7
132                swap 1
133                div_mod
134                pop 1
135                push 0
136                swap 1
137                dup 10
138                dup 9
139                swap 1
140                div_mod
141                swap 1
142                pop 1
143                push 0
144                swap 1
145                swap 6
146                pop 1
147                swap 6
148                pop 1
149                swap 6
150                pop 1
151                swap 6
152                pop 1
153                push 0
154                return
155                _binop_Eq_bool_bool_53_else:
156                return
157                _binop_Eq_bool_bool_47_then:
158                pop 1
159                dup 1
160                dup 1
161                push 0
162                push 0
163                swap 6
164                pop 1
165                swap 6
166                pop 1
167                swap 6
168                pop 1
169                swap 6
170                pop 1
171                push 0
172                return
173                _binop_Eq_bool_bool_47_else:
174                dup 9
175                push 0
176                eq
177                push 1
178                swap 1
179                skiz
180                call _binop_Eq_bool_bool_53_then
181                skiz
182                call _binop_Eq_bool_bool_53_else
183                return
184                _lit_u64_u64_99_then:
185                pop 1
186                push 0
187                push 0
188                push 0
189                return
190                _lit_u64_u64_99_else:
191                push 00000000004294967295
192                push 00000000004294967295
193                return
194                _binop_Gt_bool_bool_81_while_loop:
195                dup 4
196                push 0
197                lt
198                push 0
199                eq
200                skiz
201                return
202                dup 3
203                dup 3
204                push 1
205                call {shift_left_u64}
206                dup 8
207                dup 8
208                push 63
209                call {shift_right_u64}
210                call {or_u64}
211                swap 4
212                pop 1
213                swap 4
214                pop 1
215                dup 6
216                dup 6
217                push 1
218                call {shift_left_u64}
219                dup 3
220                dup 3
221                push 0
222                push 1
223                call {and_u64}
224                call {or_u64}
225                swap 7
226                pop 1
227                swap 7
228                pop 1
229                push {spilled_denominator_alloc.read_address()}
230                read_mem {spilled_denominator_alloc.num_words()}
231                pop 1
232                dup 5
233                dup 5
234                call {lt_u64}
235                push 1
236                swap 1
237                skiz
238                call _lit_u64_u64_99_then
239                skiz
240                call _lit_u64_u64_99_else
241                swap 2
242                pop 1
243                swap 2
244                pop 1
245                dup 3
246                dup 3
247                push {spilled_denominator_alloc.read_address()}
248                read_mem {spilled_denominator_alloc.num_words()}
249                pop 1
250                dup 5
251                dup 5
252                call {and_u64}
253                swap 3
254                swap 1
255                swap 3
256                swap 2
257                call {sub_u64}
258                swap 4
259                pop 1
260                swap 4
261                pop 1
262                dup 4
263                push 1
264                swap 1
265                call {sub_u32}
266                swap 5
267                pop 1
268                recurse
269                _binop_Or_bool_bool_44_then:
270                pop 1
271                push {spilled_denominator_alloc.read_address()}
272                read_mem {spilled_denominator_alloc.num_words()}
273                pop 1
274                push 0
275                push 1
276                swap 3
277                eq
278                swap 2
279                eq
280                mul
281                push 1
282                swap 1
283                skiz
284                call _binop_Eq_bool_bool_47_then
285                skiz
286                call _binop_Eq_bool_bool_47_else
287                push 0
288                return
289                _binop_Or_bool_bool_44_else:
290                push 0
291                push 0
292                push {spilled_denominator_alloc.read_address()}
293                read_mem {spilled_denominator_alloc.num_words()}
294                pop 1
295                swap 3
296                eq
297                swap 2
298                eq
299                mul
300                push 0
301                eq
302                assert error_id {Self::DIVISION_BY_ZERO_ERROR_ID}
303                push {spilled_denominator_alloc.read_address()}
304                read_mem {spilled_denominator_alloc.num_words()}
305                pop 1
306                call {leading_zeros_u64}
307                dup 2
308                dup 2
309                call {leading_zeros_u64}
310                swap 1
311                call {sub_u32}
312                push 1
313                call {add_u32}
314                dup 2
315                dup 2
316                dup 2
317                call {shift_right_u64}
318                dup 4
319                dup 4
320                push 64
321                dup 5
322                swap 1
323                call {sub_u32}
324                call {shift_left_u64}
325                swap 5
326                pop 1
327                swap 5
328                pop 1
329                push 0
330                push 0
331                call _binop_Gt_bool_bool_81_while_loop
332                dup 6
333                dup 6
334                push 1
335                call {shift_left_u64}
336                dup 3
337                dup 3
338                push 0
339                push 1
340                call {and_u64}
341                call {or_u64}
342                dup 5
343                dup 5
344                swap 11
345                pop 1
346                swap 11
347                pop 1
348                swap 11
349                pop 1
350                swap 11
351                pop 5
352                pop 1
353                return
354                _binop_Gt_bool_bool_26_then:
355                pop 1
356                push 0
357                push 0
358                dup 3
359                dup 3
360                swap 6
361                pop 1
362                swap 6
363                pop 1
364                swap 6
365                pop 1
366                swap 6
367                pop 1
368                push 0
369                return
370                _binop_Gt_bool_bool_26_else:
371                dup 7
372                push 0
373                eq
374                push {spilled_denominator_alloc.read_address()}
375                read_mem {spilled_denominator_alloc.num_words()}
376                pop 1
377                push 0
378                push 1
379                swap 3
380                eq
381                swap 2
382                eq
383                mul
384                add
385                push 2
386                eq
387                dup 8
388                push 0
389                eq
390                dup 11
391                push 0
392                eq
393                add
394                push 2
395                eq
396                add
397                push 0
398                eq
399                push 0
400                eq
401                push 1
402                swap 1
403                skiz
404                call _binop_Or_bool_bool_44_then
405                skiz
406                call _binop_Or_bool_bool_44_else
407                return
408        )
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use crate::library::STATIC_MEMORY_FIRST_ADDRESS;
    use crate::test_prelude::*;

    impl DivMod {
        /// Builds an isolated-run initial state with `numerator` pushed below
        /// `denominator`, matching the snippet's declared parameter order.
        fn set_up_initial_state(&self, numerator: u64, denominator: u64) -> FunctionInitialState {
            let mut stack = self.init_stack_for_isolated_run();
            push_encodable(&mut stack, &numerator);
            push_encodable(&mut stack, &denominator);

            FunctionInitialState {
                stack,
                ..Default::default()
            }
        }
    }

    impl Function for DivMod {
        /// Reference implementation: pops denominator and numerator, pushes quotient and
        /// remainder, and reports an error for a zero denominator.
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            let denominator = pop_encodable::<u64>(stack)?;
            let numerator = pop_encodable::<u64>(stack)?;
            if denominator == 0 {
                return Err(RustShadowError::Other);
            }

            let quotient = numerator / denominator;
            let remainder = numerator % denominator;
            push_encodable(stack, &quotient);
            push_encodable(stack, &remainder);

            // Accommodate spilling: the snippet writes the denominator to static memory,
            // so the shadow must leave the same memory footprint. This could probably be
            // avoided if the code was compiled by hand instead.
            encode_to_memory(memory, STATIC_MEMORY_FIRST_ADDRESS - bfe!(1), &denominator);
            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let (numerator, denominator) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (u32::MAX.into(), 1 << 15),
                Some(BenchmarkCase::WorstCase) => (u64::MAX, (1 << 32) + 45454545),
                None => StdRng::from_seed(seed).random(),
            };

            self.set_up_initial_state(numerator, denominator)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            const NOISE: u64 = 0x6d26_150f_4669_d677;

            // One value per magnitude: keep the leading bit `1 << i` and fill every bit
            // below it with a fixed noise pattern.
            let u64s_of_different_magnitudes = (0..u64::BITS)
                .step_by(3) // test performance is atrocious otherwise
                .map(|i| 1 << i)
                .map(|x| x | ((x - 1) & NOISE));

            let mut states = u64s_of_different_magnitudes
                .clone()
                .cartesian_product(u64s_of_different_magnitudes)
                .map(|(n, d)| self.set_up_initial_state(n, d))
                .collect_vec();

            let additional_inputs = [
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 100),
                (0, u32::MAX as u64),
                (0, 0xFFFF_FFFF_0000_0000),
                (0, 11428751156810088448),
                (1000, 100),
                // found in bug reports online
                (6098312677908545536, 6098805452391317504),
                (5373808693584330752, 11428751156810088448),
                (8268416007396130816, 6204028719464448000),
                // suggested by an LLM
                (u64::MAX, 1),
                (u64::MAX, 2),
                (u64::MAX, u64::MAX),
                (0x0000_0001_FFFF_FFFF, 0xFFFF_FFFF_0000_0000),
                (0xFFFF_FFFF_0000_0000, 0x0000_0000_FFFF_FFFF),
                (0xABCD_EF12_3456_789A, 0x1234_5678_9ABC_DEF0),
                // edge cases around powers of two
                (u64::MAX, (1 << 31) + 1),
                (u64::MAX, (1 << 31) + 454545454),
                (u64::MAX, (1 << 32) - 1),
                (u64::MAX, 1 << 32),
                (u64::MAX, (1 << 32) + 1),
                (u64::MAX, (1 << 32) + 2),
                (u64::MAX, (1 << 32) + 3),
                (u64::MAX, (1 << 32) + 454545454),
                (u64::MAX, (1 << 33) - 1),
                (u64::MAX, 1 << 33),
                (u64::MAX, (1 << 33) + 1),
                (u64::MAX, (1 << 33) + 454545454),
                (u64::MAX, (1 << 34) + 454545454),
                (u64::MAX, (1 << 35) + 454545454),
                (u64::MAX - 1, (1 << 32) - 2),
                (u64::MAX - 1, (1 << 32) - 1),
                (u64::MAX - 1, 1 << 32),
                (u64::MAX - 1, (1 << 32) + 1),
                (u64::MAX - 1, (1 << 32) + 2),
                (u64::MAX - 1, (1 << 32) + 3),
                (u64::MAX - 1, (1 << 33) - 1),
                (u64::MAX - 1, 1 << 33),
                (u64::MAX - 1, (1 << 33) + 1),
            ];

            states.extend(additional_inputs.map(|(n, d)| self.set_up_initial_state(n, d)));
            states
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedFunction::new(DivMod).test();
    }

    // Any u64 numerator divided by zero must make the VM fail with the
    // division-by-zero error id.
    #[macro_rules_attr::apply(proptest)]
    fn fail_vm_execution_on_divide_by_zero_u64_numerator(numerator: u64) {
        test_assertion_failure(
            &ShadowedFunction::new(DivMod),
            DivMod.set_up_initial_state(numerator, 0).into(),
            &[DivMod::DIVISION_BY_ZERO_ERROR_ID],
        );
    }
}
547
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    // Benchmarks the `DivMod` snippet through its Rust-shadowed wrapper.
    #[macro_rules_attr::apply(test)]
    fn benchmark() {
        ShadowedFunction::new(DivMod::default()).bench();
    }
}
557}