sigma_compiler_core/
codegen.rs

1//! A module for generating the code produced by this macro.  This code
2//! will interact with the underlying `sigma` macro.
3
4use super::sigma::codegen::{StructField, StructFieldList};
5use super::syntax::*;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote};
8#[cfg(test)]
9use syn::parse_quote;
10use syn::Ident;
11
12/// The main struct to handle code generation for this macro.
13///
14/// Initialize a [`CodeGen`] with the [`SigmaCompSpec`] you get by
15/// parsing the macro input.  Pass it to the various transformations and
16/// statement handlers, which will both update the code it will
17/// generate, and modify the [`SigmaCompSpec`].  Then at the end, call
18/// [`CodeGen::generate`] with the modified [`SigmaCompSpec`] to generate the
19/// code output by this macro.
20pub struct CodeGen {
21    /// The protocol name specified in the `sigma_compiler` macro
22    /// invocation
23    proto_name: Ident,
24    /// The group name specified in the `sigma_compiler` macro
25    /// invocation
26    group_name: Ident,
27    /// The variables that were explicitly listed in the
28    /// `sigma_compiler` macro invocation
29    vars: TaggedVarDict,
30    /// A prefix that does not appear at the beginning of any variable
31    /// name in `vars`
32    unique_prefix: String,
33    /// Variables (not necessarily appearing in `vars`, since they may
34    /// be generated by the sigma_compiler itself) that the prover needs
35    /// to send to the verifier along with the proof.  These could
36    /// include commitments to bits in range proofs, for example.
37    sent_instance: StructFieldList,
38    /// Extra code that will be emitted in the `prove` function
39    prove_code: TokenStream,
40    /// Extra code that will be emitted in the `verify` function
41    verify_code: TokenStream,
42    /// Extra code that will be emitted in the `verify` function before
43    /// the `sent_instance` are deserialized.  This is where the verifier
44    /// sets the lengths of vector variables in the `sent_instance`.
45    verify_pre_instance_code: TokenStream,
46}
47
48impl CodeGen {
49    /// Find a prefix that does not appear at the beginning of any
50    /// variable name in `vars`
51    fn unique_prefix(vars: &TaggedVarDict) -> String {
52        'outer: for tag in 0usize.. {
53            let try_prefix = if tag == 0 {
54                "gen__".to_string()
55            } else {
56                format!("gen{}__", tag)
57            };
58            for v in vars.keys() {
59                if v.starts_with(&try_prefix) {
60                    continue 'outer;
61                }
62            }
63            return try_prefix;
64        }
65        // The compiler complains if this isn't here, but it will only
66        // get hit if vars contains at least usize::MAX entries, which
67        // isn't going to happen.
68        String::new()
69    }
70
71    /// Create a new [`CodeGen`] given the [`SigmaCompSpec`] you get by
72    /// parsing the macro input.
73    pub fn new(spec: &SigmaCompSpec) -> Self {
74        Self {
75            proto_name: spec.proto_name.clone(),
76            group_name: spec.group_name.clone(),
77            vars: spec.vars.clone(),
78            unique_prefix: Self::unique_prefix(&spec.vars),
79            sent_instance: StructFieldList::default(),
80            prove_code: quote! {},
81            verify_code: quote! {},
82            verify_pre_instance_code: quote! {},
83        }
84    }
85
86    #[cfg(test)]
87    /// Create an empty [`CodeGen`].  Primarily useful in testing.
88    pub fn new_empty() -> Self {
89        Self {
90            proto_name: parse_quote! { proto },
91            group_name: parse_quote! { G },
92            vars: TaggedVarDict::default(),
93            unique_prefix: "gen__".into(),
94            sent_instance: StructFieldList::default(),
95            prove_code: quote! {},
96            verify_code: quote! {},
97            verify_pre_instance_code: quote! {},
98        }
99    }
100
101    /// Create a new generated private Scalar variable to put in the
102    /// Witness.
103    ///
104    /// If you call this, you should also call
105    /// [`prove_append`](Self::prove_append) with code like `quote!{ let
106    /// #id = ... }` where `id` is the [`struct@Ident`] returned from
107    /// this function.
108    pub fn gen_scalar(
109        &self,
110        vars: &mut TaggedVarDict,
111        base: &Ident,
112        is_rand: bool,
113        is_vec: bool,
114    ) -> Ident {
115        let id = format_ident!("{}{}", self.unique_prefix, base);
116        vars.insert(
117            id.to_string(),
118            TaggedIdent::Scalar(TaggedScalar {
119                id: id.clone(),
120                is_pub: false,
121                is_rand,
122                is_vec,
123            }),
124        );
125        id
126    }
127
128    /// Create a new public Point variable to put in the Instance,
129    /// optionally marking it as needing to be sent from the prover to
130    /// the verifier along with the proof.
131    ///
132    /// If you call this function, you should also call
133    /// [`prove_append`](Self::prove_append) with code like `quote!{ let
134    /// #id = ... }` where `id` is the [`struct@Ident`] returned from
135    /// this function.  If `is_vec` is `true`, then you should also call
136    /// [`verify_pre_instance_append`](Self::verify_pre_instance_append)
137    /// with code like `quote!{ let mut #id = Vec::<Point>::new();
138    /// #id.resize(#len, Point::default()); }` where `len` is the number
139    /// of elements you expect to have in the vector (computed at
140    /// runtime, perhaps based on the values of public parameters).
141    pub fn gen_point(
142        &mut self,
143        vars: &mut TaggedVarDict,
144        base: &Ident,
145        is_vec: bool,
146        send_to_verifier: bool,
147    ) -> Ident {
148        let id = format_ident!("{}{}", self.unique_prefix, base);
149        vars.insert(
150            id.to_string(),
151            TaggedIdent::Point(TaggedPoint {
152                id: id.clone(),
153                is_cind: false,
154                is_const: false,
155                is_vec,
156            }),
157        );
158        if send_to_verifier {
159            if is_vec {
160                self.sent_instance.push_vecpoint(&id);
161            } else {
162                self.sent_instance.push_point(&id);
163            }
164        }
165        id
166    }
167
168    /// Create a new identifier, using the unique prefix
169    pub fn gen_ident(&self, base: &Ident) -> Ident {
170        format_ident!("{}{}", self.unique_prefix, base)
171    }
172
173    /// Append some code to the generated `prove` function
174    pub fn prove_append(&mut self, code: TokenStream) {
175        let prove_code = &self.prove_code;
176        self.prove_code = quote! {
177            #prove_code
178            #code
179        };
180    }
181
182    /// Append some code to the generated `verify` function
183    pub fn verify_append(&mut self, code: TokenStream) {
184        let verify_code = &self.verify_code;
185        self.verify_code = quote! {
186            #verify_code
187            #code
188        };
189    }
190
191    /// Append some code to the generated `verify` function to be run
192    /// before the `sent_instance` are deserialized
193    pub fn verify_pre_instance_append(&mut self, code: TokenStream) {
194        let verify_pre_instance_code = &self.verify_pre_instance_code;
195        self.verify_pre_instance_code = quote! {
196            #verify_pre_instance_code
197            #code
198        };
199    }
200
201    /// Append some code to both the generated `prove` and `verify`
202    /// functions
203    pub fn prove_verify_append(&mut self, code: TokenStream) {
204        let prove_code = &self.prove_code;
205        self.prove_code = quote! {
206            #prove_code
207            #code
208        };
209        let verify_code = &self.verify_code;
210        self.verify_code = quote! {
211            #verify_code
212            #code
213        };
214    }
215
216    /// Append some code to both the generated `prove` and `verify`
217    /// functions, the latter to be run before the `sent_instance` are
218    /// deserialized
219    pub fn prove_verify_pre_instance_append(&mut self, code: TokenStream) {
220        let prove_code = &self.prove_code;
221        self.prove_code = quote! {
222            #prove_code
223            #code
224        };
225        let verify_pre_instance_code = &self.verify_pre_instance_code;
226        self.verify_pre_instance_code = quote! {
227            #verify_pre_instance_code
228            #code
229        };
230    }
231
232    /// Extract (as [`String`]s) the code inserted by
233    /// [`prove_append`](Self::prove_append),
234    /// [`verify_append`](Self::verify_append), and
235    /// [`verify_pre_instance_append`](Self::verify_pre_instance_append).
236    pub fn code_strings(&self) -> (String, String, String) {
237        (
238            self.prove_code.to_string(),
239            self.verify_code.to_string(),
240            self.verify_pre_instance_code.to_string(),
241        )
242    }
243
244    /// Generate the code to be output by this macro.
245    ///
246    /// `emit_prover` and `emit_verifier` are as in
247    /// [`sigma_compiler_core`](super::sigma_compiler_core).
248    pub fn generate(
249        &self,
250        spec: &mut SigmaCompSpec,
251        emit_prover: bool,
252        emit_verifier: bool,
253    ) -> TokenStream {
254        let proto_name = &self.proto_name;
255        let group_name = &self.group_name;
256
257        let group_types = quote! {
258            use super::group;
259            pub type Scalar = <super::#group_name as group::Group>::Scalar;
260            pub type Point = super::#group_name;
261        };
262
263        // vardict contains the variables that were defined in the macro
264        // call to [`sigma_compiler`]
265        let vardict = taggedvardict_to_vardict(&self.vars);
266        // sigma_proofs_vardict contains the variables that we are passing
267        // to sigma_proofs.  We may have removed some via substitution, and
268        // we may have added some when compiling statements like range
269        // assertions into underlying linear combination assertions.
270        let sigma_proofs_vardict = taggedvardict_to_vardict(&spec.vars);
271
272        // Generate the code that uses the underlying sigma_proofs API
273        let mut sigma_proofs_codegen = super::sigma::codegen::CodeGen::new(
274            format_ident!("sigma"),
275            format_ident!("Point"),
276            &sigma_proofs_vardict,
277            &mut spec.statements,
278        );
279        let sigma_proofs_code = sigma_proofs_codegen.generate(emit_prover, emit_verifier);
280
281        let mut pub_instance_fields = StructFieldList::default();
282        pub_instance_fields.push_vars(&vardict, true);
283        let mut witness_fields = StructFieldList::default();
284        witness_fields.push_vars(&vardict, false);
285
286        let mut sigma_proofs_instance_fields = StructFieldList::default();
287        sigma_proofs_instance_fields.push_vars(&sigma_proofs_vardict, true);
288        let mut sigma_proofs_witness_fields = StructFieldList::default();
289        sigma_proofs_witness_fields.push_vars(&sigma_proofs_vardict, false);
290
291        // Generate the public instance struct definition
292        let instance_def = {
293            let decls = pub_instance_fields.field_decls();
294            #[cfg(feature = "dump")]
295            let dump_impl = {
296                let dump_chunks = pub_instance_fields.dump();
297                quote! {
298                    impl Instance {
299                        fn dump_scalar(s: &Scalar) {
300                            let bytes: &[u8] = &s.to_repr();
301                            print!("{:02x?}", bytes);
302                        }
303
304                        fn dump_point(p: &Point) {
305                            let bytes: &[u8] = &p.to_bytes();
306                            print!("{:02x?}", bytes);
307                        }
308
309                        pub fn dump(&self) {
310                            #dump_chunks
311                        }
312                    }
313                }
314            };
315            #[cfg(not(feature = "dump"))]
316            let dump_impl = {
317                quote! {}
318            };
319            quote! {
320                #[derive(Clone)]
321                pub struct Instance {
322                    #decls
323                }
324
325                #dump_impl
326            }
327        };
328
329        // Generate the witness struct definition
330        let witness_def = if emit_prover {
331            let decls = witness_fields.field_decls();
332            quote! {
333                #[derive(Clone)]
334                pub struct Witness {
335                    #decls
336                }
337            }
338        } else {
339            quote! {}
340        };
341
342        // Generate the prove function
343        let prove_func = if emit_prover {
344            let instance_ids = pub_instance_fields.field_list();
345            let witness_ids = witness_fields.field_list();
346            let sigma_proofs_instance_ids = sigma_proofs_instance_fields.field_list();
347            let sigma_proofs_witness_ids = sigma_proofs_witness_fields.field_list();
348            let prove_code = &self.prove_code;
349            let codegen_instance_var = format_ident!("{}sigma_instance", self.unique_prefix);
350            let codegen_witness_var = format_ident!("{}sigma_witness", self.unique_prefix);
351            let instance_var = format_ident!("{}instance", self.unique_prefix);
352            let witness_var = format_ident!("{}witness", self.unique_prefix);
353            let rng_var = format_ident!("{}rng", self.unique_prefix);
354            let proof_var = format_ident!("{}proof", self.unique_prefix);
355            let sid_var = format_ident!("{}session_id", self.unique_prefix);
356            let sent_instance_code = {
357                let chunks = self.sent_instance.fields.iter().map(|sf| match sf {
358                    StructField::Point(id) => quote! {
359                        #proof_var.extend(sigma_proofs::serialization::serialize_elements(
360                            std::slice::from_ref(&#codegen_instance_var.#id)
361                        ));
362                    },
363                    StructField::VecPoint(id) => quote! {
364                        #proof_var.extend(sigma_proofs::serialization::serialize_elements(
365                            &#codegen_instance_var.#id
366                        ));
367                    },
368                    _ => quote! {},
369                });
370                quote! { #(#chunks)* }
371            };
372
373            let dumper = if cfg!(feature = "dump") {
374                quote! {
375                    println!("prover instance = {{");
376                    #instance_var.dump();
377                    println!("}}");
378                }
379            } else {
380                quote! {}
381            };
382
383            quote! {
384                pub fn prove(
385                    #instance_var: &Instance,
386                    #witness_var: &Witness,
387                    #sid_var: &[u8],
388                    #rng_var: &mut (impl CryptoRng + RngCore),
389                ) -> Result<Vec<u8>, SigmaError> {
390                    #dumper
391                    let Instance { #instance_ids } = #instance_var.clone();
392                    let Witness { #witness_ids } = #witness_var.clone();
393                    #prove_code
394                    let mut #proof_var = Vec::<u8>::new();
395                    let #codegen_instance_var = sigma::Instance {
396                        #sigma_proofs_instance_ids
397                    };
398                    let #codegen_witness_var = sigma::Witness {
399                        #sigma_proofs_witness_ids
400                    };
401                    #sent_instance_code
402                    #proof_var.extend(
403                        sigma::prove(
404                            &#codegen_instance_var,
405                            &#codegen_witness_var,
406                            #sid_var,
407                            #rng_var,
408                        )?
409                    );
410                    Ok(#proof_var)
411                }
412            }
413        } else {
414            quote! {}
415        };
416
417        // Generate the verify function
418        let verify_func = if emit_verifier {
419            let instance_ids = pub_instance_fields.field_list();
420            let sigma_proofs_instance_ids = sigma_proofs_instance_fields.field_list();
421            let verify_pre_instance_code = &self.verify_pre_instance_code;
422            let verify_code = &self.verify_code;
423            let codegen_instance_var = format_ident!("{}sigma_instance", self.unique_prefix);
424            let element_len_var = format_ident!("{}element_len", self.unique_prefix);
425            let offset_var = format_ident!("{}proof_offset", self.unique_prefix);
426            let instance_var = format_ident!("{}instance", self.unique_prefix);
427            let proof_var = format_ident!("{}proof", self.unique_prefix);
428            let sid_var = format_ident!("{}session_id", self.unique_prefix);
429            let sent_instance_code = {
430                let element_len_code = if self.sent_instance.fields.is_empty() {
431                    quote! {}
432                } else {
433                    quote! {
434                        let #element_len_var =
435                            <Point as group::GroupEncoding>::Repr::default().as_ref().len();
436                    }
437                };
438
439                let chunks = self.sent_instance.fields.iter().map(|sf| match sf {
440                    StructField::Point(id) => quote! {
441                        let #id: Point = sigma_proofs::serialization::deserialize_elements(
442                                &#proof_var[#offset_var..],
443                                1,
444                            ).ok_or(SigmaError::VerificationFailure)?[0];
445                        #offset_var += #element_len_var;
446                    },
447                    StructField::VecPoint(id) => quote! {
448                        #id = sigma_proofs::serialization::deserialize_elements(
449                                &#proof_var[#offset_var..],
450                                #id.len(),
451                            ).ok_or(SigmaError::VerificationFailure)?;
452                        #offset_var += #element_len_var * #id.len();
453                    },
454                    _ => quote! {},
455                });
456
457                quote! {
458                    let mut #offset_var = 0usize;
459                    #element_len_code
460                    #(#chunks)*
461                }
462            };
463
464            let dumper = if cfg!(feature = "dump") {
465                quote! {
466                    println!("verifier instance = {{");
467                    #instance_var.dump();
468                    println!("}}");
469                }
470            } else {
471                quote! {}
472            };
473
474            quote! {
475                pub fn verify(
476                    #instance_var: &Instance,
477                    #proof_var: &[u8],
478                    #sid_var: &[u8],
479                ) -> Result<(), SigmaError> {
480                    #dumper
481                    let Instance { #instance_ids } = #instance_var.clone();
482                    #verify_pre_instance_code
483                    #sent_instance_code
484                    #verify_code
485                    let #codegen_instance_var = sigma::Instance {
486                        #sigma_proofs_instance_ids
487                    };
488                    sigma::verify(
489                        &#codegen_instance_var,
490                        &#proof_var[#offset_var..],
491                        #sid_var,
492                    )
493                }
494            }
495        } else {
496            quote! {}
497        };
498
499        // Output the generated module for this protocol
500        let dump_use = if cfg!(feature = "dump") {
501            quote! {
502                use group::GroupEncoding;
503            }
504        } else {
505            quote! {}
506        };
507        quote! {
508            #[allow(non_snake_case)]
509            pub mod #proto_name {
510                use super::sigma_compiler;
511                use sigma_compiler::group::Group;
512                use sigma_compiler::group::ff::{Field, PrimeField};
513                use sigma_compiler::group::ff::derive::subtle::ConditionallySelectable;
514                use sigma_compiler::rand::{CryptoRng, RngCore};
515                use sigma_compiler::sigma_proofs;
516                use sigma_compiler::sigma_proofs::errors::Error as SigmaError;
517                use sigma_compiler::vecutils::*;
518                use std::ops::Neg;
519                #dump_use
520
521                #group_types
522
523                #sigma_proofs_code
524
525                #instance_def
526                #witness_def
527                #prove_func
528                #verify_func
529            }
530        }
531    }
532}