secp256k1_zkp/zkp/
surjection_proof.rs

1use crate::ffi;
2use crate::from_hex;
3use crate::Verification;
4use crate::{Error, Generator, Secp256k1};
5use core::mem::size_of;
6use std::str;
7
8/// Represents a surjection proof.
9#[derive(Debug, PartialEq, Clone, Eq, Hash, PartialOrd, Ord)]
10pub struct SurjectionProof {
11    inner: ffi::SurjectionProof,
12}
13
#[cfg(feature = "actual-rand")]
mod with_rand {
    use super::*;
    use crate::{Signing, Tag, Tweak};
    use ffi::CPtr;
    use rand::Rng;

    /// Largest domain libsecp256k1-zkp accepts for a single proof
    /// (`SECP256K1_SURJECTIONPROOF_MAX_N_INPUTS` in the C library's default build).
    const MAX_DOMAIN_SIZE: usize = 256;

    impl SurjectionProof {
        /// Prove that a given tag - when blinded - is contained within another set of blinded tags.
        ///
        /// Mathematically, we are proving that there exists a surjective mapping between the domain and codomain of tags.
        /// Blinding a tag produces a [`Generator`]. As such, to create this proof we need to provide the [`Generator`]s
        /// and the respective blinding factors that were used to create them.
        ///
        /// Each `domain` entry is `(blinded tag, unblinded tag, blinding factor)`. A 32-byte seed drawn
        /// from `rng` drives libsecp256k1-zkp's selection of at most three domain entries for the proof.
        /// The library retries that selection up to 100 times until it includes an entry whose tag equals
        /// `codomain_tag`.
        ///
        /// # Errors
        ///
        /// Returns [`Error::CannotProveSurjection`] if
        /// - `domain` is empty or has more than 256 entries,
        /// - no attempted selection contained an entry whose tag equals `codomain_tag` (always the case
        ///   when `codomain_tag` does not appear in `domain`), or
        /// - the library fails to generate the proof from the given blinding factors.
        pub fn new<C: Signing, R: Rng>(
            secp: &Secp256k1<C>,
            rng: &mut R,
            codomain_tag: Tag,
            codomain_blinding_factor: Tweak,
            domain: &[(Generator, Tag, Tweak)],
        ) -> Result<SurjectionProof, Error> {
            // An empty domain can never contain the codomain tag. An oversized one fails an argument
            // check inside the C library, which reports it through the illegal-argument callback
            // rather than a return code. Reject both here with a proper error.
            if domain.is_empty() || domain.len() > MAX_DOMAIN_SIZE {
                return Err(Error::CannotProveSurjection);
            }

            let mut proof = ffi::SurjectionProof::new();

            let mut seed = [0u8; 32];
            rng.fill_bytes(&mut seed);

            // Upper bound on random selections tried by `secp256k1_surjectionproof_initialize`.
            let max_iterations = 100;
            // Written by `secp256k1_surjectionproof_initialize`: index of the matched domain entry.
            let mut domain_index = 0;

            let domain_tags: Vec<_> = domain.iter().map(|(_, tag, _)| tag.into_inner()).collect();
            let domain_blinded_tags: Vec<_> = domain
                .iter()
                .map(|(blinded_tag, _, _)| *blinded_tag.as_inner())
                .collect();

            // SAFETY: `domain_tags` holds exactly `domain.len()` elements. `domain.len().min(3)`
            // never exceeds that count. Every other pointer refers to a local or argument that
            // outlives the call.
            let ret = unsafe {
                ffi::secp256k1_surjectionproof_initialize(
                    secp.ctx().as_ptr(),
                    &mut proof,
                    &mut domain_index,
                    domain_tags.as_ptr(),
                    domain.len(),
                    domain.len().min(3),
                    codomain_tag.as_inner(),
                    max_iterations,
                    seed.as_ptr(),
                )
            };
            if ret == 0 {
                return Err(Error::CannotProveSurjection);
            }

            // Blinding factor of the domain entry the library matched. Use a checked lookup so an
            // out-of-range index becomes an error instead of a panic.
            // TODO: Return dedicated error here?
            let domain_blinding_factor = domain
                .get(domain_index)
                .map(|(_, _, bf)| bf)
                .ok_or(Error::CannotProveSurjection)?;

            let codomain_blinded_tag =
                Generator::new_blinded(secp, codomain_tag, codomain_blinding_factor);

            // SAFETY: `domain_blinded_tags` holds exactly `domain.len()` elements. The blinding
            // factor pointers borrow `domain` and `codomain_blinding_factor`, both alive for the call.
            let ret = unsafe {
                ffi::secp256k1_surjectionproof_generate(
                    secp.ctx().as_ptr(),
                    &mut proof,
                    domain_blinded_tags.as_ptr(),
                    domain.len(),
                    codomain_blinded_tag.as_inner(),
                    domain_index,
                    domain_blinding_factor.as_c_ptr(),
                    codomain_blinding_factor.as_c_ptr(),
                )
            };
            if ret == 0 {
                return Err(Error::CannotProveSurjection);
            }

            Ok(SurjectionProof { inner: proof })
        }
    }
}
97
98impl SurjectionProof {
99    /// Creates a surjection proof from a slice of bytes.
100    pub fn from_slice(bytes: &[u8]) -> Result<Self, Error> {
101        let mut proof = ffi::SurjectionProof::new();
102
103        let ret = unsafe {
104            ffi::secp256k1_surjectionproof_parse(
105                ffi::secp256k1_context_no_precomp,
106                &mut proof,
107                bytes.as_ptr(),
108                bytes.len(),
109            )
110        };
111
112        if ret != 1 {
113            return Err(Error::InvalidSurjectionProof);
114        }
115
116        Ok(SurjectionProof { inner: proof })
117    }
118
119    /// Serializes a surjection proof.
120    ///
121    /// The format of this serialization is stable and platform-independent.
122    pub fn serialize(&self) -> Vec<u8> {
123        let mut size = unsafe {
124            ffi::secp256k1_surjectionproof_serialized_size(
125                ffi::secp256k1_context_no_precomp,
126                &self.inner,
127            )
128        };
129
130        let mut bytes = vec![0u8; size];
131
132        let ret = unsafe {
133            ffi::secp256k1_surjectionproof_serialize(
134                ffi::secp256k1_context_no_precomp,
135                bytes.as_mut_ptr(),
136                &mut size,
137                &self.inner,
138            )
139        };
140        assert_eq!(ret, 1, "failed to serialize surjection proof"); // This is safe as long as we correctly computed the size of the proof upfront using `secp256k1_surjectionproof_serialized_size`.
141
142        bytes
143    }
144
145    /// Find the length of surjection proof when serialized
146    #[allow(clippy::len_without_is_empty)]
147    pub fn len(&self) -> usize {
148        unsafe {
149            ffi::secp256k1_surjectionproof_serialized_size(
150                ffi::secp256k1_context_no_precomp,
151                &self.inner,
152            )
153        }
154    }
155
156    /// Whether the proof has zero length
157    ///
158    /// Always returns `false` since a surjection proof must contain at least
159    /// one 32-byte hash.
160    pub fn is_empty(&self) -> bool {
161        false
162    }
163
164    /// Verify a surjection proof.
165    #[must_use]
166    pub fn verify<C: Verification>(
167        &self,
168        secp: &Secp256k1<C>,
169        codomain: Generator,
170        domain: &[Generator],
171    ) -> bool {
172        // Safety: Generator and ffi::PublicKey are the same size and layout.
173        let domain_blinded_tags = unsafe {
174            debug_assert_eq!(size_of::<Generator>(), size_of::<ffi::PublicKey>());
175
176            &*(domain as *const [Generator] as *const [ffi::PublicKey])
177        };
178
179        let ret = unsafe {
180            ffi::secp256k1_surjectionproof_verify(
181                secp.ctx().as_ptr(),
182                &self.inner,
183                domain_blinded_tags.as_ptr(),
184                domain_blinded_tags.len(),
185                codomain.as_inner(),
186            )
187        };
188
189        ret == 1
190    }
191}
192
#[cfg(feature = "hashes")]
impl ::core::fmt::Display for SurjectionProof {
    /// Writes the serialized proof as lowercase hex, two digits per byte.
    fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        self.serialize()
            .iter()
            .try_for_each(|byte| write!(f, "{:02x}", byte))
    }
}
201
202impl str::FromStr for SurjectionProof {
203    type Err = Error;
204    fn from_str(s: &str) -> Result<SurjectionProof, Error> {
205        let mut res = vec![0u8; s.len() / 2];
206        match from_hex(s, &mut res) {
207            Ok(_) => SurjectionProof::from_slice(&res),
208            _ => Err(Error::InvalidSurjectionProof),
209        }
210    }
211}
212
#[cfg(all(feature = "serde", feature = "hashes"))]
impl ::serde::Serialize for SurjectionProof {
    /// Emits raw proof bytes for binary formats and the `Display` hex string otherwise.
    fn serialize<S: ::serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        if !s.is_human_readable() {
            return s.serialize_bytes(&self.serialize());
        }
        s.collect_str(self)
    }
}
223
#[cfg(all(feature = "serde", feature = "hashes"))]
impl<'de> ::serde::Deserialize<'de> for SurjectionProof {
    /// Reads raw proof bytes from binary formats and a hex string from human-readable ones.
    fn deserialize<D: ::serde::Deserializer<'de>>(d: D) -> Result<SurjectionProof, D::Error> {
        if !d.is_human_readable() {
            let bytes_visitor = crate::serde_util::BytesVisitor::new(
                "a bytestring",
                SurjectionProof::from_slice,
            );
            return d.deserialize_bytes(bytes_visitor);
        }
        d.deserialize_str(crate::serde_util::FromStrVisitor::new("an ASCII hex string"))
    }
}
239
#[cfg(all(test, feature = "global-context"))] // use global context for convenience
mod tests {
    use super::*;
    use crate::{Tag, Tweak, SECP256K1};
    use rand::thread_rng;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[test]
    fn test_create_and_verify_surjection_proof() {
        // create three random tags
        let (domain_tag_1, domain_blinded_tag_1, domain_bf_1) = random_blinded_tag();
        let (domain_tag_2, domain_blinded_tag_2, domain_bf_2) = random_blinded_tag();
        let (domain_tag_3, domain_blinded_tag_3, domain_bf_3) = random_blinded_tag();

        // pick the first one as the codomain
        let codomain_tag_1 = domain_tag_1;
        let (codomain_blinded_tag_1, codomain_bf_1) = blind_tag(codomain_tag_1);

        let proof = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            codomain_tag_1,
            codomain_bf_1,
            &[
                (domain_blinded_tag_1, domain_tag_1, domain_bf_1),
                (domain_blinded_tag_2, domain_tag_2, domain_bf_2),
                (domain_blinded_tag_3, domain_tag_3, domain_bf_3),
            ],
        )
        .unwrap();

        assert!(proof.verify(
            SECP256K1,
            codomain_blinded_tag_1,
            &[
                domain_blinded_tag_1,
                domain_blinded_tag_2,
                domain_blinded_tag_3
            ],
        ))
    }

    #[test]
    fn test_verify_rejects_mismatched_inputs() {
        let (domain_tag_1, domain_blinded_tag_1, domain_bf_1) = random_blinded_tag();
        let (domain_tag_2, domain_blinded_tag_2, domain_bf_2) = random_blinded_tag();
        let (codomain_blinded_tag, codomain_bf) = blind_tag(domain_tag_1);

        let proof = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            domain_tag_1,
            codomain_bf,
            &[
                (domain_blinded_tag_1, domain_tag_1, domain_bf_1),
                (domain_blinded_tag_2, domain_tag_2, domain_bf_2),
            ],
        )
        .unwrap();
        let domain = [domain_blinded_tag_1, domain_blinded_tag_2];

        assert!(proof.verify(SECP256K1, codomain_blinded_tag, &domain));

        // A codomain blinded from a tag outside the domain must not verify.
        let (_, unrelated_blinded_tag, _) = random_blinded_tag();
        assert!(!proof.verify(SECP256K1, unrelated_blinded_tag, &domain));

        // Nor may the proof be checked against a domain of a different size.
        assert!(!proof.verify(SECP256K1, codomain_blinded_tag, &domain[..1]));
    }

    #[test]
    fn test_cannot_prove_tag_outside_domain() {
        let (domain_tag, domain_blinded_tag, domain_bf) = random_blinded_tag();
        let outside_tag = Tag::random();
        let (_, outside_bf) = blind_tag(outside_tag);

        let result = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            outside_tag,
            outside_bf,
            &[(domain_blinded_tag, domain_tag, domain_bf)],
        );
        assert!(result.is_err());

        // An empty domain can never contain the codomain tag.
        let empty = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            outside_tag,
            outside_bf,
            &[],
        );
        assert!(empty.is_err());
    }

    #[test]
    fn test_serialize_and_parse_surjection_proof() {
        let (domain_tag_1, domain_blinded_tag_1, domain_bf_1) = random_blinded_tag();
        let codomain_tag_1 = domain_tag_1;
        let (_, codomain_bf_1) = blind_tag(codomain_tag_1);

        let proof = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            codomain_tag_1,
            codomain_bf_1,
            &[(domain_blinded_tag_1, domain_tag_1, domain_bf_1)],
        )
        .unwrap();
        let bytes = proof.serialize();
        let parsed = SurjectionProof::from_slice(&bytes).unwrap();

        assert_eq!(parsed, proof);

        #[cfg(feature = "hashes")]
        {
            use std::str::FromStr;
            use std::string::ToString;
            let proof_str = proof.to_string();
            assert_eq!(proof, SurjectionProof::from_str(&proof_str).unwrap());
        }
    }

    #[test]
    fn test_from_str_rejects_malformed_input() {
        assert!("abc".parse::<SurjectionProof>().is_err()); // odd number of hex digits
        assert!("zz".parse::<SurjectionProof>().is_err()); // not hex
        assert!("".parse::<SurjectionProof>().is_err()); // too short to be a proof
    }

    fn random_blinded_tag() -> (Tag, Generator, Tweak) {
        let tag = Tag::random();

        let (blinded_tag, bf) = blind_tag(tag);

        (tag, blinded_tag, bf)
    }

    fn blind_tag(tag: Tag) -> (Generator, Tweak) {
        let bf = Tweak::new(&mut thread_rng());
        let blinded_tag = Generator::new_blinded(SECP256K1, tag, bf);

        (blinded_tag, bf)
    }
}