tasm_lib/arithmetic/u64/overflowing_sub.rs

1use std::collections::HashMap;
2
3use triton_vm::prelude::*;
4
5use crate::prelude::*;
6use crate::traits::basic_snippet::Reviewer;
7use crate::traits::basic_snippet::SignOffFingerprint;
8
/// [Overflowing subtraction][sub] for unsigned 64-bit integers.
///
/// # Behavior
///
/// ```text
/// BEFORE: _ [subtrahend: u64] [minuend: u64]
/// AFTER:  _ [difference: u64] [is_overflow: bool]
/// ```
///
/// # Preconditions
///
/// - all input arguments are properly [`BFieldCodec`] encoded
///
/// # Postconditions
///
/// - the output `difference` is the `minuend` minus the `subtrahend`, wrapped
///   modulo 2^64 (i.e., `minuend.wrapping_sub(subtrahend)`)
/// - the output `is_overflow` is `true` if and only if the subtrahend is
///   greater than the minuend, i.e., if and only if the subtraction underflows
/// - the output is properly [`BFieldCodec`] encoded
///
/// [sub]: u64::overflowing_sub
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct OverflowingSub;
32
impl OverflowingSub {
    /// The code shared between [`crate::arithmetic::u64::sub::Sub`],
    /// [`crate::arithmetic::u64::wrapping_sub::WrappingSub`], and
    /// [`OverflowingSub`]. Take care to treat the `difference_hi` correctly,
    /// depending on how you want to handle overflow.
    ///
    /// ```text
    /// BEFORE: _ subtrahend_hi subtrahend_lo minuend_hi minuend_lo
    /// AFTER:  _ difference_lo (minuend_hi - subtrahend_hi - carry)
    /// ```
    ///
    /// Here, `carry` is the borrow out of the low limbs: it is 1 if
    /// `minuend_lo < subtrahend_lo` and 0 otherwise.
    ///
    /// The top stack element is left unreduced. It is the field element
    /// representing the integer `minuend_hi - subtrahend_hi - carry`. For
    /// properly encoded inputs (all four limbs are u32s), that integer lies in
    /// `[-2^32, 2^32 - 1]`. It is non-negative, and then equal to
    /// `difference_hi`, if and only if `minuend >= subtrahend`. It is negative
    /// exactly when the 64-bit subtraction overflows. Handling the negative
    /// case is left to the caller.
    pub(crate) fn common_subtraction_code() -> Vec<LabelledInstruction> {
        triton_asm! {
            // BEFORE: _ subtrahend_hi subtrahend_lo minuend_hi minuend_lo
            // AFTER:  _ difference_lo (minuend_hi - subtrahend_hi - carry)
            pick 2
            // _ subtrahend_hi minuend_hi minuend_lo subtrahend_lo

            push -1
            mul
            add
            // _ subtrahend_hi minuend_hi (minuend_lo - subtrahend_lo)

            // As an integer, `minuend_lo - subtrahend_lo` lies in
            // [-(2^32 - 1), 2^32 - 1]; a negative value signals a borrow.
            // Adding 2^32 shifts it into [1, 2^33 - 1]. That range is
            // non-negative and well below the field modulus, so `split`
            // decomposes the true integer value:
            // - the high limb is 1 if minuend_lo >= subtrahend_lo (no borrow)
            //   and 0 otherwise (borrow), i.e., it is `!carry`;
            // - the low limb is (minuend_lo - subtrahend_lo) mod 2^32, which
            //   is exactly `difference_lo`.
            addi {1_u64 << 32}
            split
            // _ subtrahend_hi minuend_hi !carry difference_lo

            place 3
            // _ difference_lo subtrahend_hi minuend_hi !carry

            push 0
            eq
            // _ difference_lo subtrahend_hi minuend_hi carry

            pick 2
            add
            // _ difference_lo minuend_hi (subtrahend_hi + carry)

            push -1
            mul
            add
            // _ difference_lo (minuend_hi - subtrahend_hi - carry)
        }
    }
}
85
86impl BasicSnippet for OverflowingSub {
87    fn inputs(&self) -> Vec<(DataType, String)> {
88        ["subtrahend", "minuend"]
89            .map(|s| (DataType::U64, s.to_string()))
90            .to_vec()
91    }
92
93    fn outputs(&self) -> Vec<(DataType, String)> {
94        vec![
95            (DataType::U64, "wrapped_diff".to_string()),
96            (DataType::Bool, "is_overflow".to_string()),
97        ]
98    }
99
100    fn entrypoint(&self) -> String {
101        "tasmlib_arithmetic_u64_overflowing_sub".to_string()
102    }
103
104    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
105        triton_asm!(
106            {self.entrypoint()}:
107                {&Self::common_subtraction_code()}
108                // _ difference_lo (minuend_hi - subtrahend_hi - carry)
109
110                addi {1_u64 << 32}
111                split
112                // _ difference_lo !is_overflow difference_hi
113
114                place 2
115                // _ difference_hi difference_lo !is_overflow
116
117                push 0
118                eq
119                // _ difference_hi difference_lo is_overflow
120
121                return
122        )
123    }
124
125    fn sign_offs(&self) -> HashMap<Reviewer, SignOffFingerprint> {
126        let mut sign_offs = HashMap::new();
127        sign_offs.insert(Reviewer("ferdinand"), 0x4e4c796ae06e4400.into());
128        sign_offs
129    }
130}
131
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::test_prelude::*;

    impl OverflowingSub {
        /// Values at and within three of a handful of anchor points, in
        /// ascending order per anchor. Offsets that would leave the `u64`
        /// range are skipped, so `0` is reached via `1 - 1` and nothing
        /// beyond `u64::MAX` is produced.
        pub fn edge_case_values() -> Vec<u64> {
            let anchors = [1, 1 << 32, 1 << 33, 1 << 34, 1 << 40, 1 << 63, u64::MAX];

            let mut values = Vec::new();
            for anchor in anchors {
                values.extend((1..=3).rev().filter_map(|offset| anchor.checked_sub(offset)));
                values.push(anchor);
                values.extend((1..=3).filter_map(|offset| anchor.checked_add(offset)));
            }
            values
        }
    }

    impl Closure for OverflowingSub {
        type Args = (u64, u64);

        /// Mirrors the snippet with [`u64::overflowing_sub`]: pops
        /// `(subtrahend, minuend)` and pushes `(difference, is_overflow)`.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (subtrahend, minuend) = pop_encodable::<Self::Args>(stack);
            let (difference, is_overflow) = minuend.overflowing_sub(subtrahend);
            push_encodable(stack, &(difference, is_overflow));
        }

        fn pseudorandom_args(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> Self::Args {
            match bench_case {
                None => StdRng::from_seed(seed).random(),
                Some(BenchmarkCase::CommonCase) => ((1 << 63) - 1, 1 << 63),
                Some(BenchmarkCase::WorstCase) => (1 << 50, 1 << 63),
            }
        }

        /// Every ordered `(subtrahend, minuend)` pair of edge-case values.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let values = Self::edge_case_values();

            values
                .iter()
                .copied()
                .cartesian_product(values.iter().copied())
                .collect()
        }
    }

    #[test]
    fn rust_shadow() {
        let closure = ShadowedClosure::new(OverflowingSub);
        closure.test();
    }
}
195
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark for this snippet through its shadowed closure.
    #[test]
    fn benchmark() {
        let closure = ShadowedClosure::new(OverflowingSub);
        closure.bench();
    }
}
205}