1use crate::ffi;
2use crate::from_hex;
3use crate::Verification;
4use crate::{Error, Generator, Secp256k1};
5use core::mem::size_of;
6use std::str;
7
8#[derive(Debug, PartialEq, Clone, Eq, Hash, PartialOrd, Ord)]
10pub struct SurjectionProof {
11 inner: ffi::SurjectionProof,
12}
13
#[cfg(feature = "actual-rand")]
mod with_rand {
    use super::*;
    use crate::{Signing, Tag, Tweak};
    use ffi::CPtr;
    use rand::Rng;

    impl SurjectionProof {
        /// Creates a surjection proof that the blinded codomain tag (built from
        /// `codomain_tag` and `codomain_blinding_factor`) corresponds to one of the
        /// blinded tags in `domain`.
        ///
        /// Each `domain` entry is `(blinded_tag, tag, blinding_factor)`. The library
        /// is asked to use `min(domain.len(), 3)` of the entries, selected using a
        /// 32-byte seed drawn from `rng`.
        ///
        /// # Errors
        ///
        /// Returns [`Error::CannotProveSurjection`] if
        /// `secp256k1_surjectionproof_initialize` or
        /// `secp256k1_surjectionproof_generate` reports failure, or if the input
        /// index chosen by the library does not exist in `domain`.
        pub fn new<C: Signing, R: Rng>(
            secp: &Secp256k1<C>,
            rng: &mut R,
            codomain_tag: Tag,
            codomain_blinding_factor: Tweak,
            domain: &[(Generator, Tag, Tweak)],
        ) -> Result<SurjectionProof, Error> {
            let mut proof = ffi::SurjectionProof::new();

            let mut seed = [0u8; 32];
            rng.fill_bytes(&mut seed);

            // Passed to `secp256k1_surjectionproof_initialize` as `n_max_iterations`.
            let max_iteration = 100;

            // Contiguous FFI arrays, in the same order as `domain`, so indices
            // reported by the library can be mapped back onto `domain`.
            let domain_blinded_tags: Vec<_> = domain
                .iter()
                .map(|(blinded_tag, _, _)| *blinded_tag.as_inner())
                .collect();
            let domain_tags: Vec<_> = domain.iter().map(|(_, tag, _)| tag.into_inner()).collect();

            let mut domain_index = 0;

            // SAFETY: every pointer refers to a live, properly initialized value or
            // buffer owned by this frame (`proof`, `domain_index`, `domain_tags`,
            // `seed`) or borrowed for the duration of the call (`codomain_tag`);
            // `domain_tags` holds exactly `domain.len()` elements.
            let ret = unsafe {
                ffi::secp256k1_surjectionproof_initialize(
                    secp.ctx().as_ptr(),
                    &mut proof,
                    &mut domain_index,
                    domain_tags.as_ptr(),
                    domain.len(),
                    domain.len().min(3),
                    codomain_tag.as_inner(),
                    max_iteration,
                    seed.as_ptr(),
                )
            };

            if ret == 0 {
                return Err(Error::CannotProveSurjection);
            }

            // Resolve the selected input's blinding factor before entering the
            // unsafe block so the fallible lookup stays in safe code.
            let (_, _, input_blinding_factor) =
                domain.get(domain_index).ok_or(Error::CannotProveSurjection)?;

            let codomain_blinded_tag =
                Generator::new_blinded(secp, codomain_tag, codomain_blinding_factor);

            // SAFETY: `domain_blinded_tags` holds exactly `domain.len()` elements;
            // `proof` was initialized above; the generator and blinding-factor
            // pointers reference values that outlive the call.
            let ret = unsafe {
                ffi::secp256k1_surjectionproof_generate(
                    secp.ctx().as_ptr(),
                    &mut proof,
                    domain_blinded_tags.as_ptr(),
                    domain.len(),
                    codomain_blinded_tag.as_inner(),
                    domain_index,
                    input_blinding_factor.as_c_ptr(),
                    codomain_blinding_factor.as_c_ptr(),
                )
            };

            if ret == 0 {
                return Err(Error::CannotProveSurjection);
            }

            Ok(SurjectionProof { inner: proof })
        }
    }
}
97
98impl SurjectionProof {
99 pub fn from_slice(bytes: &[u8]) -> Result<Self, Error> {
101 let mut proof = ffi::SurjectionProof::new();
102
103 let ret = unsafe {
104 ffi::secp256k1_surjectionproof_parse(
105 ffi::secp256k1_context_no_precomp,
106 &mut proof,
107 bytes.as_ptr(),
108 bytes.len(),
109 )
110 };
111
112 if ret != 1 {
113 return Err(Error::InvalidSurjectionProof);
114 }
115
116 Ok(SurjectionProof { inner: proof })
117 }
118
119 pub fn serialize(&self) -> Vec<u8> {
123 let mut size = unsafe {
124 ffi::secp256k1_surjectionproof_serialized_size(
125 ffi::secp256k1_context_no_precomp,
126 &self.inner,
127 )
128 };
129
130 let mut bytes = vec![0u8; size];
131
132 let ret = unsafe {
133 ffi::secp256k1_surjectionproof_serialize(
134 ffi::secp256k1_context_no_precomp,
135 bytes.as_mut_ptr(),
136 &mut size,
137 &self.inner,
138 )
139 };
140 assert_eq!(ret, 1, "failed to serialize surjection proof"); bytes
143 }
144
145 #[allow(clippy::len_without_is_empty)]
147 pub fn len(&self) -> usize {
148 unsafe {
149 ffi::secp256k1_surjectionproof_serialized_size(
150 ffi::secp256k1_context_no_precomp,
151 &self.inner,
152 )
153 }
154 }
155
156 pub fn is_empty(&self) -> bool {
161 false
162 }
163
164 #[must_use]
166 pub fn verify<C: Verification>(
167 &self,
168 secp: &Secp256k1<C>,
169 codomain: Generator,
170 domain: &[Generator],
171 ) -> bool {
172 let domain_blinded_tags = unsafe {
174 debug_assert_eq!(size_of::<Generator>(), size_of::<ffi::PublicKey>());
175
176 &*(domain as *const [Generator] as *const [ffi::PublicKey])
177 };
178
179 let ret = unsafe {
180 ffi::secp256k1_surjectionproof_verify(
181 secp.ctx().as_ptr(),
182 &self.inner,
183 domain_blinded_tags.as_ptr(),
184 domain_blinded_tags.len(),
185 codomain.as_inner(),
186 )
187 };
188
189 ret == 1
190 }
191}
192
#[cfg(feature = "hashes")]
impl ::core::fmt::Display for SurjectionProof {
    /// Writes the serialized proof as lowercase hex (`{:x}`).
    fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
        use internals::hex::display::DisplayHex;

        let encoded = self.serialize();
        let hex = encoded.as_slice().as_hex();
        write!(f, "{:x}", &hex)
    }
}
201
202impl str::FromStr for SurjectionProof {
203 type Err = Error;
204 fn from_str(s: &str) -> Result<SurjectionProof, Error> {
205 let mut res = vec![0u8; s.len() / 2];
206 match from_hex(s, &mut res) {
207 Ok(_) => SurjectionProof::from_slice(&res),
208 _ => Err(Error::InvalidSurjectionProof),
209 }
210 }
211}
212
#[cfg(all(feature = "serde", feature = "hashes"))]
impl ::serde::Serialize for SurjectionProof {
    /// Emits raw bytes for binary serializers and the `Display` (hex) form for
    /// human-readable ones.
    fn serialize<S: ::serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        if !s.is_human_readable() {
            let raw = self.serialize();
            return s.serialize_bytes(&raw);
        }
        s.collect_str(&self)
    }
}
223
#[cfg(all(feature = "serde", feature = "hashes"))]
impl<'de> ::serde::Deserialize<'de> for SurjectionProof {
    /// Accepts raw bytes from binary deserializers and a hex string (via
    /// `FromStr`) from human-readable ones.
    fn deserialize<D: ::serde::Deserializer<'de>>(d: D) -> Result<SurjectionProof, D::Error> {
        use crate::serde_util::{BytesVisitor, FromStrVisitor};

        if !d.is_human_readable() {
            let visitor = BytesVisitor::new("a bytestring", SurjectionProof::from_slice);
            return d.deserialize_bytes(visitor);
        }
        d.deserialize_str(FromStrVisitor::new("an ASCII hex string"))
    }
}
239
#[cfg(all(test, feature = "global-context"))]
mod tests {
    use super::*;
    use crate::{Tag, Tweak, SECP256K1};
    use rand::thread_rng;

    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[test]
    fn test_create_and_verify_surjection_proof() {
        let (domain_tag_1, domain_blinded_tag_1, domain_bf_1) = random_blinded_tag();
        let (domain_tag_2, domain_blinded_tag_2, domain_bf_2) = random_blinded_tag();
        let (domain_tag_3, domain_blinded_tag_3, domain_bf_3) = random_blinded_tag();

        // The codomain reuses the first domain tag so a surjection exists.
        let codomain_tag_1 = domain_tag_1;
        let (codomain_blinded_tag_1, codomain_bf_1) = blind_tag(codomain_tag_1);

        let proof = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            codomain_tag_1,
            codomain_bf_1,
            &[
                (domain_blinded_tag_1, domain_tag_1, domain_bf_1),
                (domain_blinded_tag_2, domain_tag_2, domain_bf_2),
                (domain_blinded_tag_3, domain_tag_3, domain_bf_3),
            ],
        )
        .unwrap();

        assert!(proof.verify(
            SECP256K1,
            codomain_blinded_tag_1,
            &[
                domain_blinded_tag_1,
                domain_blinded_tag_2,
                domain_blinded_tag_3
            ],
        ))
    }

    #[test]
    fn test_verify_rejects_mismatched_inputs() {
        let (domain_tag_1, domain_blinded_tag_1, domain_bf_1) = random_blinded_tag();
        let (domain_tag_2, domain_blinded_tag_2, domain_bf_2) = random_blinded_tag();
        let (codomain_blinded_tag, codomain_bf) = blind_tag(domain_tag_1);

        let proof = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            domain_tag_1,
            codomain_bf,
            &[
                (domain_blinded_tag_1, domain_tag_1, domain_bf_1),
                (domain_blinded_tag_2, domain_tag_2, domain_bf_2),
            ],
        )
        .unwrap();
        let domain = [domain_blinded_tag_1, domain_blinded_tag_2];

        // Sanity: the honest statement verifies, also after a serialization round-trip.
        assert!(proof.verify(SECP256K1, codomain_blinded_tag, &domain));
        let parsed = SurjectionProof::from_slice(&proof.serialize()).unwrap();
        assert!(parsed.verify(SECP256K1, codomain_blinded_tag, &domain));

        // An unrelated codomain must be rejected.
        let (_, unrelated_blinded_tag, _) = random_blinded_tag();
        assert!(!proof.verify(SECP256K1, unrelated_blinded_tag, &domain));

        // A domain of the wrong size must be rejected.
        assert!(!proof.verify(SECP256K1, codomain_blinded_tag, &domain[..1]));
    }

    #[test]
    fn test_cannot_prove_codomain_outside_domain() {
        let (domain_tag, domain_blinded_tag, domain_bf) = random_blinded_tag();
        let (codomain_tag, _, codomain_bf) = random_blinded_tag();

        let result = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            codomain_tag,
            codomain_bf,
            &[(domain_blinded_tag, domain_tag, domain_bf)],
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_serialize_and_parse_surjection_proof() {
        let (domain_tag_1, domain_blinded_tag_1, domain_bf_1) = random_blinded_tag();
        let codomain_tag_1 = domain_tag_1;
        let (_, codomain_bf_1) = blind_tag(codomain_tag_1);

        let proof = SurjectionProof::new(
            SECP256K1,
            &mut thread_rng(),
            codomain_tag_1,
            codomain_bf_1,
            &[(domain_blinded_tag_1, domain_tag_1, domain_bf_1)],
        )
        .unwrap();
        let bytes = proof.serialize();
        assert_eq!(bytes.len(), proof.len());
        let parsed = SurjectionProof::from_slice(&bytes).unwrap();

        assert_eq!(parsed, proof);

        #[cfg(feature = "hashes")]
        {
            use std::str::FromStr;
            use std::string::ToString;
            let proof_str = proof.to_string();
            assert_eq!(proof, SurjectionProof::from_str(&proof_str).unwrap());
        }
    }

    /// Returns a random tag together with a freshly blinded generator and its
    /// blinding factor.
    fn random_blinded_tag() -> (Tag, Generator, Tweak) {
        let tag = Tag::random();

        let (blinded_tag, bf) = blind_tag(tag);

        (tag, blinded_tag, bf)
    }

    /// Blinds `tag` with a random blinding factor.
    fn blind_tag(tag: Tag) -> (Generator, Tweak) {
        let bf = Tweak::new(&mut thread_rng());
        let blinded_tag = Generator::new_blinded(SECP256K1, tag, bf);

        (blinded_tag, bf)
    }
}