tasm_lib/hashing/lt_digest.rs

1use tip5::Digest;
2use triton_vm::prelude::*;
3
4use crate::prelude::*;
5
/// Compare two digests stored in memory and report whether `lhs` is strictly
/// greater than `rhs`.
///
/// Note that, despite the `lt` in the name, the result is `rhs < lhs`, i.e.
/// `lhs > rhs`.
///
/// The digests are compared element by element, starting at the element with
/// the highest index (`Digest::LEN - 1`) and stopping at the first pair of
/// elements that differ. That element index is therefore the most significant
/// one. Each element is compared via its canonical `u64` value. Two equal
/// digests yield `false`.
///
/// ```text
/// BEFORE: _ *digest_lhs *digest_rhs
/// AFTER:  _ (digest_lhs > digest_rhs)
/// ```
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct LtDigest;
9
10impl BasicSnippet for LtDigest {
11    fn inputs(&self) -> Vec<(DataType, String)> {
12        vec![
13            (DataType::VoidPointer, "*digest_lhs".to_owned()),
14            (DataType::VoidPointer, "*digest_rhs".to_owned()),
15        ]
16    }
17
18    fn outputs(&self) -> Vec<(DataType, String)> {
19        vec![(DataType::Bool, "digest_lhs > digest_rhs".to_owned())]
20    }
21
22    fn entrypoint(&self) -> String {
23        "tasmlib_hashing_lt_digest".to_owned()
24    }
25
26    fn code(&self, _library: &mut Library) -> Vec<LabelledInstruction> {
27        let entrypoint = self.entrypoint();
28
29        let loop_label = format!("{entrypoint}_short_circuiting_loop");
30
31        // result ∈ {0, 1} : {0: !(lhs > rhs), 1: lhs > rhs}
32        let compare_bfes_loop = triton_asm!(
33            // Invariant: _ result (*lhs - 1) *lhs[i] *rhs[i] g0 g1
34            {loop_label}:
35                hint result: bool = stack[5]
36                hint lhs_end_condition: BFieldElement = stack[4]
37                hint lhs_i_ptr = stack[3]
38                hint rhs_i_ptr = stack[2]
39
40                /* Check if we are done */
41                dup 4
42                dup 4
43                eq
44                skiz return
45
46                pop 2
47                // _ result (*lhs - 1) *lhs[i] *rhs[i]
48
49                read_mem 1
50                // _ result (*lhs - 1) *lhs[i] rhs[i] *rhs[i-1]
51
52                swap 2
53                read_mem 1
54                // _ result (*lhs - 1) *rhs[i-1] rhs[i] lhs[i] *lhs[i-1]
55
56                swap 3
57                swap 2
58                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i]
59
60                dup 1 dup 1
61                eq
62                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i] (lhs[i] == rhs[i])
63
64                skiz recurse
65                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i]
66
67                split
68                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] lhs[i] rhs[i]_hi rhs[i]_lo
69
70                swap 2
71                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] rhs[i]_lo rhs[i]_hi lhs[i]
72
73                split
74                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] rhs[i]_lo rhs[i]_hi lhs[i]_hi lhs[i]_lo
75
76                /* Calculate rhs[i]_hi < lhs[i]_hi || rhs[i]_hi == rhs[i]_hi && rhs[i]_lo < lhs[i]_lo*/
77
78                dup 1 dup 3 lt
79                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] rhs[i]_lo rhs[i]_hi lhs[i]_hi lhs[i]_lo (lhs[i]_hi > rhs[i]_hi)
80
81                swap 4
82                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) rhs[i]_hi lhs[i]_hi lhs[i]_lo rhs[i]_lo
83
84                lt
85                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) rhs[i]_hi lhs[i]_hi (lhs[i]_lo > rhs[i]_lo)
86
87                swap 2 eq
88                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) (lhs[i]_lo > rhs[i]_lo) (lhs[i]_hi == rhs[i]_hi)
89
90                mul
91                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) ((lhs[i]_lo > rhs[i]_lo) && (lhs[i]_hi == rhs[i]_hi))
92
93                add
94                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] (lhs[i]_hi > rhs[i]_hi) || ((lhs[i]_lo > rhs[i]_lo) && (lhs[i]_hi == rhs[i]_hi))
95                // _ result (*lhs - 1) *lhs[i-1] *rhs[i-1] result'
96
97                swap 4
98                push 0
99                // _ result' (*lhs - 1) *lhs[i-1] *rhs[i-1] g0 g1
100
101                return
102        );
103
104        triton_asm!(
105            // BEFORE: _ *lhs *rhs
106            // AFTER:  _ (lhs > rhs)
107            {entrypoint}:
108                // _ *lhs *rhs
109
110                // goal:  _ result (*lhs - 1) *lhs[i] *rhs[i] g0 g1
111
112                push 0
113                swap 2
114                // _ 0 *rhs *lhs
115
116                push {Digest::LEN - 1} add
117                // _ 0 *rhs *lhs[4]
118
119                dup 0 push {-(Digest::LEN as isize)} add
120                // _ 0 *rhs *lhs[4] (*lhs - 1)
121
122                swap 2
123                // _ 0 (*lhs - 1) *lhs[4] *rhs
124
125                push {Digest::LEN - 1} add
126                // _ 0 (*lhs - 1) *lhs[4] *rhs[4]
127
128                push 0
129                push 0
130                // _ 0 (*lhs - 1) *lhs[4] *rhs[4] 0 0
131                // _ result (*lhs - 1) *lhs[4] *rhs[4] g0 g1
132
133                call {loop_label}
134                // _ result (*lhs - 1) *lhs[i] *rhs[i] g0 g1
135
136                pop 5
137                // _ result
138
139                return
140
141            {&compare_bfes_loop}
142        )
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::encode_to_memory;
    use crate::test_prelude::*;

    impl LtDigest {
        /// Build an initial state with `lhs` encoded at `lhs_ptr` and `rhs`
        /// encoded at `rhs_ptr`. The pointers are pushed on top of the isolated
        /// run's initial stack, with `rhs_ptr` on top.
        fn prepare_state(
            &self,
            lhs_ptr: BFieldElement,
            rhs_ptr: BFieldElement,
            lhs: Digest,
            rhs: Digest,
        ) -> FunctionInitialState {
            let mut memory = HashMap::default();
            encode_to_memory(&mut memory, lhs_ptr, &lhs);
            encode_to_memory(&mut memory, rhs_ptr, &rhs);

            let stack = [self.init_stack_for_isolated_run(), vec![lhs_ptr, rhs_ptr]].concat();

            FunctionInitialState { stack, memory }
        }
    }

    impl Function for LtDigest {
        fn rust_shadow(
            &self,
            stack: &mut Vec<BFieldElement>,
            memory: &mut HashMap<BFieldElement, BFieldElement>,
        ) {
            // Read the `Digest::LEN` consecutive words starting at `ptr`.
            let load_digest = |ptr: BFieldElement| {
                Digest::new(std::array::from_fn(|i| memory[&(ptr + bfe!(i as u64))]))
            };
            let rhs_ptr = stack.pop().unwrap();
            let lhs_ptr = stack.pop().unwrap();
            let rhs = load_digest(rhs_ptr);
            let lhs = load_digest(lhs_ptr);

            stack.push(bfe!((rhs < lhs) as u64));
        }

        fn pseudorandom_initial_state(
            &self,
            seed: [u8; 32],
            bench_case: Option<BenchmarkCase>,
        ) -> FunctionInitialState {
            let mut rng = StdRng::from_seed(seed);

            let a_digest = Digest::new([bfe!(1), bfe!(2), bfe!(3), bfe!(4), bfe!(5)]);
            let another_digest = Digest::new([bfe!(6), bfe!(7), bfe!(8), bfe!(9), bfe!(10)]);

            // Common case: the very first (most significant) element differs.
            // Worst case: equal digests, so every element must be inspected.
            let (lhs, rhs) = match bench_case {
                Some(BenchmarkCase::CommonCase) => (a_digest, another_digest),
                Some(BenchmarkCase::WorstCase) => (a_digest, a_digest),
                None => (rng.random(), rng.random()),
            };
            let rhs_ptr: BFieldElement = rng.random();
            let lhs_ptr: BFieldElement = rng.random();

            self.prepare_state(lhs_ptr, rhs_ptr, lhs, rhs)
        }

        fn corner_case_initial_states(&self) -> Vec<FunctionInitialState> {
            let lhs_ptr = bfe!(0);
            let rhs_ptr = bfe!(5);

            let a_digest = Digest::new([bfe!(1), bfe!(2), bfe!(3), bfe!(4), bfe!(5)]);
            let another_digest = Digest::new([bfe!(6), bfe!(7), bfe!(8), bfe!(9), bfe!(10)]);
            let digests_on_lowest_addresses =
                self.prepare_state(lhs_ptr, rhs_ptr, a_digest, another_digest);
            let digests_on_lowest_addresses_swapped =
                self.prepare_state(lhs_ptr, rhs_ptr, another_digest, a_digest);
            let same_digest_values = self.prepare_state(lhs_ptr, rhs_ptr, a_digest, a_digest);

            // The element at the highest index decides the order, even when all
            // other elements point the other way.
            let big_low_elements = Digest::new([bfe!(9), bfe!(9), bfe!(9), bfe!(9), bfe!(1)]);
            let big_high_element = Digest::new([bfe!(1), bfe!(1), bfe!(1), bfe!(1), bfe!(2)]);
            let most_significant_element_decides = vec![
                self.prepare_state(lhs_ptr, rhs_ptr, big_low_elements, big_high_element),
                self.prepare_state(lhs_ptr, rhs_ptr, big_high_element, big_low_elements),
            ];

            // The high 32-bit halves decide even when the low halves disagree:
            // 2^32 has (hi, lo) = (1, 0), 2^32 - 1 has (hi, lo) = (0, 2^32 - 1).
            let zero = bfe!(0);
            let hi_bigger = Digest::new([zero, zero, zero, zero, bfe!(1u64 << 32)]);
            let lo_bigger = Digest::new([zero, zero, zero, zero, bfe!((1u64 << 32) - 1)]);
            let split_halves = vec![
                self.prepare_state(lhs_ptr, rhs_ptr, hi_bigger, lo_bigger),
                self.prepare_state(lhs_ptr, rhs_ptr, lo_bigger, hi_bigger),
            ];

            // Extreme canonical values: all zeros vs. all `p - 1`.
            let smallest = Digest::new([zero; Digest::LEN]);
            let largest = Digest::new([zero - bfe!(1); Digest::LEN]);
            let extreme_values = vec![
                self.prepare_state(lhs_ptr, rhs_ptr, smallest, largest),
                self.prepare_state(lhs_ptr, rhs_ptr, largest, smallest),
                self.prepare_state(lhs_ptr, rhs_ptr, largest, largest),
            ];

            let mut adjacent_digest_pairs = vec![];
            let mut rng = rand::rng();
            for i in 0..Digest::LEN {
                let a: Digest = rng.random();

                // `lhs == rhs`
                adjacent_digest_pairs.push(self.prepare_state(lhs_ptr, rhs_ptr, a, a));

                // `b` "bigger" than `a` by $2^j$ at index i
                // (b might be smaller than a due to wrap-around)
                for j in 0..64 {
                    let mut b = a;
                    b.0[i] += bfe!(1u64 << j);
                    adjacent_digest_pairs.push(self.prepare_state(lhs_ptr, rhs_ptr, a, b));
                    adjacent_digest_pairs.push(self.prepare_state(lhs_ptr, rhs_ptr, b, a));
                }
            }

            [
                vec![
                    digests_on_lowest_addresses,
                    digests_on_lowest_addresses_swapped,
                    same_digest_values,
                ],
                most_significant_element_decides,
                split_halves,
                extreme_values,
                adjacent_digest_pairs,
            ]
            .concat()
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedFunction::new(LtDigest).test()
    }
}
257
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark for [`LtDigest`] through its shadowed-function
    /// harness.
    #[test]
    fn lt_digest_bench() {
        let shadowed_snippet = ShadowedFunction::new(LtDigest);
        shadowed_snippet.bench();
    }
}