Skip to main content

tasm_lib/hashing/
lt_digest.rs

1use tip5::Digest;
2use triton_vm::prelude::*;
3
4use crate::prelude::*;
5
/// Compare two [`Digest`]s stored in memory.
///
/// Takes pointers to two digests and returns `true` iff `digest_lhs > digest_rhs`,
/// i.e. iff `digest_rhs < digest_lhs`. Mind the direction: despite the type's
/// name, the result is *not* `lhs < rhs`.
///
/// Digests are compared element by element, starting at the element with the
/// highest index (index `Digest::LEN - 1` is the most significant) and stopping
/// at the first pair of elements that differ. Each element is compared by its
/// canonical value, split into a high and a low 32-bit limb.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct LtDigest;

10impl BasicSnippet for LtDigest {
11    fn parameters(&self) -> Vec<(DataType, String)> {
12        vec![
13            (DataType::VoidPointer, "*digest_lhs".to_owned()),
14            (DataType::VoidPointer, "*digest_rhs".to_owned()),
15        ]
16    }
17
18    fn return_values(&self) -> Vec<(DataType, String)> {
19        vec![(DataType::Bool, "digest_lhs > digest_rhs".to_owned())]
20    }
21
22    fn entrypoint(&self) -> String {
23        "tasmlib_hashing_lt_digest".to_owned()
24    }
25
26    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
27        let entrypoint = self.entrypoint();
28
29        let loop_label = format!("{entrypoint}_short_circuiting_loop");
30
31        // result ∈ {0, 1} : {0: !(lhs > rhs), 1: lhs > rhs}
32        let compare_bfes_loop = triton_asm!(
33            // Invariant: _ result (*lhs - 1) *lhs[i] *rhs[i] g0 g1
34            {loop_label}:
35                hint result: bool = stack[5]
36                hint lhs_end_condition: BFieldElement = stack[4]
37                hint lhs_i_ptr = stack[3]
38                hint rhs_i_ptr = stack[2]
39
40                /* Check if we are done */
41                dup 4
42                dup 4
43                eq
44                skiz return
45
46                pop 2
47                // _ result (*lhs - 1) *lhs[i] *rhs[i]
48
49                read_mem 1
50                // _ result (*lhs - 1) *lhs[i] rhs[i] *rhs[i-1]
51
52                swap 2
53                read_mem 1
54                // _ result (*lhs - 1) *rhs[i-1] rhs[i] lhs[i] *lhs[i-1]
55
56                swap 3
57                swap 2
58                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i]
59
60                dup 1 dup 1
61                eq
62                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i] (lhs[i] == rhs[i])
63
64                skiz recurse
65                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i]
66
67                split
68                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i]_hi rhs[i]_lo
69
70                swap 2
71                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] rhs[i]_lo rhs[i]_hi lhs[i]
72
73                split
74                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] rhs[i]_lo rhs[i]_hi lhs[i]_hi lhs[i]_lo
75
76                /* Calculate rhs[i]_hi < lhs[i]_hi || rhs[i]_hi == rhs[i]_hi && rhs[i]_lo < lhs[i]_lo*/
77
78                dup 1 dup 3 lt
79                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] rhs[i]_lo rhs[i]_hi lhs[i]_hi lhs[i]_lo (lhs[i]_hi > rhs[i]_hi)
80
81                swap 4
82                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) rhs[i]_hi lhs[i]_hi lhs[i]_lo rhs[i]_lo
83
84                lt
85                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) rhs[i]_hi lhs[i]_hi (lhs[i]_lo > rhs[i]_lo)
86
87                swap 2 eq
88                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) (lhs[i]_lo > rhs[i]_lo) (lhs[i]_hi == rhs[i]_hi)
89
90                mul
91                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) ((lhs[i]_lo > rhs[i]_lo) && (lhs[i]_hi == rhs[i]_hi))
92
93                add
94                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) || ((lhs[i]_lo > rhs[i]_lo) && (lhs[i]_hi == rhs[i]_hi))
95                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] result'
96
97                swap 4
98                push 0
99                // _ result' (*lhs - 1) *lhs[i-1] *rhs[i-1] g0 g1
100
101                return
102        );
103
104        triton_asm!(
105            // BEFORE: _ *lhs *rhs
106            // AFTER:  _ (lhs > rhs)
107            {entrypoint}:
108                // _ *lhs *rhs
109
110                // goal:  _ result (*lhs - 1) *lhs[i] *rhs[i] g0 g1
111
112                push 0
113                swap 2
114                // _ 0 *rhs *lhs
115
116                push {Digest::LEN - 1} add
117                // _ 0 *rhs *lhs[4]
118
119                dup 0 push {-(Digest::LEN as isize)} add
120                // _ 0 *rhs *lhs[4] (*lhs - 1)
121
122                swap 2
123                // _ 0 (*lhs - 1) *lhs[4] *rhs
124
125                push {Digest::LEN - 1} add
126                // _ 0 (*lhs - 1) *lhs[4] *rhs[4]
127
128                push 0
129                push 0
130                // _ 0 (*lhs - 1) *lhs[4] *rhs[4] 0 0
131                // _ result (*lhs - 1) *lhs[4] *rhs[4] g0 g1
132
133                call {loop_label}
134                // _ result (*lhs - 1) *lhs[i] *rhs[i] g0 g1
135
136                pop 5
137                // _ result
138
139                return
140
141            {&compare_bfes_loop}
142        )
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::encode_to_memory;
    use crate::test_prelude::*;

    impl LtDigest {
        /// Build an initial state with `lhs` encoded at `lhs_ptr`, `rhs` encoded at
        /// `rhs_ptr`, and the stack `_ *lhs *rhs` that the snippet expects.
        fn prepare_state(
            &self,
            lhs_ptr: BFieldElement,
            rhs_ptr: BFieldElement,
            lhs: Digest,
            rhs: Digest,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            encode_to_memory(&mut memory, lhs_ptr, &lhs);
            encode_to_memory(&mut memory, rhs_ptr, &rhs);

            let stack = [self.init_stack_for_isolated_run(), vec![lhs_ptr, rhs_ptr]].concat();

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for LtDigest {
        /// Reference implementation: replaces `_ *lhs *rhs` with `_ (lhs > rhs)`,
        /// where the comparison is `Digest`'s own ordering (`rhs < lhs`).
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) -> Result<(), RustShadowError> {
            // Reads `Digest::LEN` consecutive words starting at `ptr`. Like the
            // snippet, addresses are computed with field arithmetic.
            let load_digest = |ptr: BFieldElement| {
                Digest::new(std::array::from_fn(|i| memory[&(ptr + bfe!(i as u64))]))
            };
            let rhs_ptr = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let lhs_ptr = stack.pop().ok_or(RustShadowError::StackUnderflow)?;
            let rhs = load_digest(rhs_ptr);
            let lhs = load_digest(lhs_ptr);

            stack.push(bfe!((rhs < lhs) as u64));
            Ok(())
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            let a_digest = Digest::new([bfe!(1), bfe!(2), bfe!(3), bfe!(4), bfe!(5)]);
            let another_digest = Digest::new([bfe!(6), bfe!(7), bfe!(8), bfe!(9), bfe!(10)]);

            // Equal digests force the loop to visit every element (worst case);
            // digests differing in the top element exit after one step.
            let (lhs, rhs) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (a_digest, another_digest),
                Some(BenchmarkCase::WorstCase) => (a_digest, a_digest),
                None => (rng.random(), rng.random()),
            };
            let rhs_ptr: BFieldElement = rng.random();
            let lhs_ptr: BFieldElement = rng.random();

            self.prepare_state(lhs_ptr, rhs_ptr, lhs, rhs)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let lhs_ptr = bfe!(0);
            let rhs_ptr = bfe!(5);

            let a_digest = Digest::new([bfe!(1), bfe!(2), bfe!(3), bfe!(4), bfe!(5)]);
            let another_digest = Digest::new([bfe!(6), bfe!(7), bfe!(8), bfe!(9), bfe!(10)]);
            let digests_on_lowest_addresses =
                self.prepare_state(lhs_ptr, rhs_ptr, a_digest, another_digest);
            let same_digest_values = self.prepare_state(lhs_ptr, rhs_ptr, a_digest, a_digest);

            // For every index `i`: a random digest `a` compared with itself, and
            // with each `b` that equals `a` except at index `i`, in both orders.
            let mut adjacent_digest_pairs = vec![];
            let mut rng = rand::rng();
            for i in 0..Digest::LEN {
                let a: Digest = rng.random();

                // `lhs == rhs`
                adjacent_digest_pairs.push(self.prepare_state(lhs_ptr, rhs_ptr, a, a));

                // `b` "bigger" than `a` by $2^j$ at index i
                // (b might be smaller than a due to wrap-around)
                for j in 0..64 {
                    let mut b = a;
                    b.0[i] += bfe!(1u64 << j);
                    adjacent_digest_pairs.push(self.prepare_state(lhs_ptr, rhs_ptr, a, b));
                    adjacent_digest_pairs.push(self.prepare_state(lhs_ptr, rhs_ptr, b, a));
                }
            }

            [
                vec![digests_on_lowest_addresses, same_digest_values],
                adjacent_digest_pairs,
            ]
            .concat()
        }
    }

    #[macro_rules_attr::apply(test)]
    fn rust_shadow() {
        ShadowedFunction::new(LtDigest).test()
    }
}

#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmark [`LtDigest`] through the shadowed-function harness.
    #[macro_rules_attr::apply(test)]
    fn lt_digest_bench() {
        let snippet = LtDigest;
        ShadowedFunction::new(snippet).bench()
    }
}