1use super::combiners::StatementTree;
7use super::types::{expr_type_tokens_id_closure, AExprType, VarDict};
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote, ToTokens};
10use std::collections::HashSet;
11use syn::{Expr, Ident};
12
13pub enum StructField {
15 Scalar(Ident),
16 VecScalar(Ident),
17 Point(Ident),
18 VecPoint(Ident),
19}
20
21#[derive(Default)]
23pub struct StructFieldList {
24 pub fields: Vec<StructField>,
25}
26
27impl StructFieldList {
28 pub fn push_scalar(&mut self, s: &Ident) {
29 self.fields.push(StructField::Scalar(s.clone()));
30 }
31 pub fn push_vecscalar(&mut self, s: &Ident) {
32 self.fields.push(StructField::VecScalar(s.clone()));
33 }
34 pub fn push_point(&mut self, s: &Ident) {
35 self.fields.push(StructField::Point(s.clone()));
36 }
37 pub fn push_vecpoint(&mut self, s: &Ident) {
38 self.fields.push(StructField::VecPoint(s.clone()));
39 }
40 pub fn push_vars(&mut self, vars: &VarDict, for_instance: bool) {
41 for (id, ti) in vars.iter() {
42 match ti {
43 AExprType::Scalar { is_pub, is_vec, .. } => {
44 if *is_pub == for_instance {
45 if *is_vec {
46 self.push_vecscalar(&format_ident!("{}", id))
47 } else {
48 self.push_scalar(&format_ident!("{}", id))
49 }
50 }
51 }
52 AExprType::Point { is_vec, .. } => {
53 if for_instance {
54 if *is_vec {
55 self.push_vecpoint(&format_ident!("{}", id))
56 } else {
57 self.push_point(&format_ident!("{}", id))
58 }
59 }
60 }
61 }
62 }
63 }
64 #[cfg(feature = "dump")]
65 pub fn dump(&self) -> impl ToTokens {
67 let dump_chunks = self.fields.iter().map(|f| match f {
68 StructField::Scalar(id) => quote! {
69 print!(" {}: ", stringify!(#id));
70 Instance::dump_scalar(&self.#id);
71 println!("");
72 },
73 StructField::VecScalar(id) => quote! {
74 print!(" {}: [", stringify!(#id));
75 for s in self.#id.iter() {
76 print!(" ");
77 Instance::dump_scalar(s);
78 println!(",");
79 }
80 println!(" ]");
81 },
82 StructField::Point(id) => quote! {
83 print!(" {}: ", stringify!(#id));
84 Instance::dump_point(&self.#id);
85 println!("");
86 },
87 StructField::VecPoint(id) => quote! {
88 print!(" {}: [", stringify!(#id));
89 for p in self.#id.iter() {
90 print!(" ");
91 Instance::dump_point(p);
92 println!(",");
93 }
94 println!(" ]");
95 },
96 });
97 quote! { #(#dump_chunks)* }
98 }
99 pub fn field_decls(&self) -> impl ToTokens {
102 let decls = self.fields.iter().map(|f| match f {
103 StructField::Scalar(id) => quote! {
104 pub #id: Scalar,
105 },
106 StructField::VecScalar(id) => quote! {
107 pub #id: Vec<Scalar>,
108 },
109 StructField::Point(id) => quote! {
110 pub #id: Point,
111 },
112 StructField::VecPoint(id) => quote! {
113 pub #id: Vec<Point>,
114 },
115 });
116 quote! { #(#decls)* }
117 }
118 pub fn field_list(&self) -> impl ToTokens {
120 let field_ids = self.fields.iter().map(|f| match f {
121 StructField::Scalar(id) => quote! {
122 #id,
123 },
124 StructField::VecScalar(id) => quote! {
125 #id,
126 },
127 StructField::Point(id) => quote! {
128 #id,
129 },
130 StructField::VecPoint(id) => quote! {
131 #id,
132 },
133 });
134 quote! { #(#field_ids)* }
135 }
136}
137
138pub struct CodeGen<'a> {
140 proto_name: Ident,
141 group_name: Ident,
142 vars: &'a VarDict,
143 unique_prefix: String,
144 statements: &'a mut StatementTree,
145}
146
147impl<'a> CodeGen<'a> {
148 fn unique_prefix(vars: &VarDict) -> String {
151 'outer: for tag in 0usize.. {
152 let try_prefix = if tag == 0 {
153 "sigma__".to_string()
154 } else {
155 format!("sigma{}__", tag)
156 };
157 for v in vars.keys() {
158 if v.starts_with(&try_prefix) {
159 continue 'outer;
160 }
161 }
162 return try_prefix;
163 }
164 String::new()
168 }
169
170 pub fn new(
171 proto_name: Ident,
172 group_name: Ident,
173 vars: &'a VarDict,
174 statements: &'a mut StatementTree,
175 ) -> Self {
176 Self {
177 proto_name,
178 group_name,
179 vars,
180 unique_prefix: Self::unique_prefix(vars),
181 statements,
182 }
183 }
184
185 fn linear_relation_codegen(&self, exprs: &[&Expr]) -> (TokenStream, TokenStream) {
192 let instance_var = format_ident!("{}instance", self.unique_prefix);
193 let lr_var = format_ident!("{}lr", self.unique_prefix);
194 let mut allocated_vars: HashSet<Ident> = HashSet::new();
195 let mut param_vec_code = quote! {};
196 let mut witness_vec_code = quote! {};
197 let mut witness_code = quote! {};
198 let mut scalar_allocs = quote! {};
199 let mut element_allocs = quote! {};
200 let mut eq_code = quote! {};
201 let mut element_assigns = quote! {};
202
203 for (i, expr) in exprs.iter().enumerate() {
204 let eq_id = format_ident!("{}eq{}", self.unique_prefix, i + 1);
205 let vec_index_var = format_ident!("{}i", self.unique_prefix);
206 let vec_len_var = format_ident!("{}veclen{}", self.unique_prefix, i + 1);
207
208 let mut vec_param_vars: HashSet<Ident> = HashSet::new();
211 let mut vec_witness_vars: HashSet<Ident> = HashSet::new();
212
213 let Expr::Assign(syn::ExprAssign { left, right, .. }) = expr else {
221 let expr_str = quote! { #expr }.to_string();
222 panic!("Unrecognized expression: {expr_str}");
223 };
224 let (left_type, left_tokens) =
225 expr_type_tokens_id_closure(self.vars, left, &mut |id, id_type| match id_type {
226 AExprType::Scalar { is_pub: false, .. } => {
227 panic!("Left side of = contains a private Scalar");
228 }
229 AExprType::Scalar {
230 is_vec: false,
231 is_pub: true,
232 ..
233 }
234 | AExprType::Point { is_vec: false, .. } => Ok(quote! {#instance_var.#id}),
235 AExprType::Scalar {
236 is_vec: true,
237 is_pub: true,
238 ..
239 }
240 | AExprType::Point { is_vec: true, .. } => {
241 vec_param_vars.insert(id.clone());
242 Ok(quote! {#instance_var.#id})
243 }
244 })
245 .unwrap();
246 let AExprType::Point {
247 is_pub: true,
248 is_vec: left_is_vec,
249 } = left_type
250 else {
251 let expr_str = quote! { #expr }.to_string();
252 panic!("Left side of = does not evaluate to a public point: {expr_str}");
253 };
254 let Ok((right_type, right_tokens)) =
255 expr_type_tokens_id_closure(self.vars, right, &mut |id, id_type| match id_type {
256 AExprType::Scalar {
257 is_vec: false,
258 is_pub: false,
259 ..
260 } => {
261 if allocated_vars.insert(id.clone()) {
262 scalar_allocs = quote! {
263 #scalar_allocs
264 let #id = #lr_var.allocate_scalar();
265 };
266 witness_code = quote! {
267 #witness_code
268 witnessvec.push(witness.#id);
269 };
270 }
271 Ok(quote! {#id})
272 }
273 AExprType::Scalar {
274 is_vec: false,
275 is_pub: true,
276 ..
277 } => Ok(quote! {#instance_var.#id}),
278 AExprType::Scalar {
279 is_vec: true,
280 is_pub: false,
281 ..
282 } => {
283 vec_witness_vars.insert(id.clone());
284 if allocated_vars.insert(id.clone()) {
285 scalar_allocs = quote! {
286 #scalar_allocs
287 let #id = (0..#vec_len_var)
288 .map(|i| #lr_var.allocate_scalar())
289 .collect::<Vec<_>>();
290 };
291 witness_code = quote! {
292 #witness_code
293 witnessvec.extend(witness.#id.clone());
294 };
295 }
296 Ok(quote! { #id })
297 }
298 AExprType::Scalar {
299 is_vec: true,
300 is_pub: true,
301 ..
302 } => {
303 vec_param_vars.insert(id.clone());
304 Ok(quote! {#instance_var.#id})
305 }
306 AExprType::Point { is_vec: false, .. } => {
307 if allocated_vars.insert(id.clone()) {
308 element_allocs = quote! {
309 #element_allocs
310 let #id = #lr_var.allocate_element();
311 };
312 element_assigns = quote! {
313 #element_assigns
314 #lr_var.set_element(#id, #instance_var.#id);
315 };
316 }
317 Ok(quote! {#id})
318 }
319 AExprType::Point { is_vec: true, .. } => {
320 vec_param_vars.insert(id.clone());
321 if allocated_vars.insert(id.clone()) {
322 element_allocs = quote! {
323 #element_allocs
324 let #id = (0..#vec_len_var)
325 .map(|#vec_index_var| #lr_var.allocate_element())
326 .collect::<Vec<_>>();
327 };
328 element_assigns = quote! {
329 #element_assigns
330 for #vec_index_var in 0..#vec_len_var {
331 #lr_var.set_element(
332 #id[#vec_index_var],
333 #instance_var.#id[#vec_index_var],
334 );
335 }
336 };
337 }
338 Ok(quote! { #id })
339 }
340 })
341 else {
342 let expr_str = quote! { #expr }.to_string();
343 panic!("Right side of = is not a valid arithmetic expression: {expr_str}");
344 };
345 let AExprType::Point {
346 is_vec: right_is_vec,
347 ..
348 } = right_type
349 else {
350 let expr_str = quote! { #expr }.to_string();
351 panic!("Right side of = does not evaluate to a Point: {expr_str}");
352 };
353 if left_is_vec != right_is_vec {
354 let expr_str = quote! { #expr }.to_string();
355 panic!("Only one side of = is a vector expression: {expr_str}");
356 }
357 let vec_param_varvec = Vec::from_iter(vec_param_vars);
358 let vec_witness_varvec = Vec::from_iter(vec_witness_vars);
359
360 if !vec_param_varvec.is_empty() {
361 let firstvar = &vec_param_varvec[0];
362 param_vec_code = quote! {
363 #param_vec_code
364 let #vec_len_var = #instance_var.#firstvar.len();
365 };
366 for thisvar in vec_param_varvec.iter().skip(1) {
367 param_vec_code = quote! {
368 #param_vec_code
369 if #vec_len_var != #instance_var.#thisvar.len() {
370 eprintln!(
371 "Instance variables {} and {} must have the same length",
372 stringify!(#firstvar),
373 stringify!(#thisvar),
374 );
375 return Err(SigmaError::VerificationFailure);
376 }
377 };
378 }
379 if !vec_witness_varvec.is_empty() {
380 witness_vec_code = quote! {
381 #witness_vec_code
382 let #vec_len_var = instance.#firstvar.len();
383 };
384 }
385 for witvar in vec_witness_varvec {
386 witness_vec_code = quote! {
387 #witness_vec_code
388 if #vec_len_var != witness.#witvar.len() {
389 eprintln!(
390 "Instance variables {} and {} must have the same length",
391 stringify!(#firstvar),
392 stringify!(#witvar),
393 );
394 return Err(SigmaError::VerificationFailure);
395 }
396 }
397 }
398 };
399 if right_is_vec {
400 eq_code = quote! {
401 #eq_code
402 let #eq_id = (#right_tokens)
403 .iter()
404 .cloned()
405 .map(|lr| #lr_var.allocate_eq(lr))
406 .collect::<Vec<_>>();
407 };
408 element_assigns = quote! {
409 #element_assigns
410 (#left_tokens)
411 .iter()
412 .zip(#eq_id.iter())
413 .for_each(|(l,eq)| #lr_var.set_element(*eq, *l));
414 };
415 } else {
416 eq_code = quote! {
417 #eq_code
418 let #eq_id = #lr_var.allocate_eq(#right_tokens);
419 };
420 element_assigns = quote! {
421 #element_assigns
422 #lr_var.set_element(#eq_id, #left_tokens);
423 }
424 }
425 }
426
427 (
428 quote! {
429 {
430 let mut #lr_var = LinearRelation::<Point>::new();
431 #param_vec_code
432 #scalar_allocs
433 #element_allocs
434 #eq_code
435 #element_assigns
436
437 SigmaOk(ComposedRelation::try_from(#lr_var).unwrap())
438 }
439 },
440 quote! {
441 {
442 #witness_vec_code
443 let mut witnessvec = Vec::new();
444 #witness_code
445 SigmaOk(ComposedWitness::Simple(witnessvec))
446 }
447 },
448 )
449 }
450
451 fn proto_witness_codegen(&self, statement: &StatementTree) -> (TokenStream, TokenStream) {
460 match statement {
461 StatementTree::Leaf(_) if statement.is_leaf_true() => (
464 quote! {
465 Ok(ComposedRelation::try_from(LinearRelation::<Point>::new()).unwrap())
466 },
467 quote! {
468 Ok(ComposedWitness::Simple(vec![]))
469 },
470 ),
471 StatementTree::Leaf(leafexpr) => {
474 self.linear_relation_codegen(std::slice::from_ref(&leafexpr))
475 }
476 StatementTree::And(stvec) => {
481 let mut leaves: Vec<&Expr> = Vec::new();
482 let mut others: Vec<&StatementTree> = Vec::new();
483 for st in stvec {
484 match st {
485 StatementTree::Leaf(le) => leaves.push(le),
486 _ => others.push(st),
487 }
488 }
489 let (proto_code, witness_code) = self.linear_relation_codegen(&leaves);
490 if others.is_empty() {
491 (proto_code, witness_code)
492 } else {
493 let (others_proto, others_witness): (Vec<TokenStream>, Vec<TokenStream>) =
494 others
495 .iter()
496 .map(|st| self.proto_witness_codegen(st))
497 .unzip();
498 (
499 quote! {
500 SigmaOk(ComposedRelation::and([
501 #proto_code?,
502 #(#others_proto?,)*
503 ]))
504 },
505 quote! {
506 SigmaOk(ComposedWitness::and([
507 #witness_code?,
508 #(#others_witness?,)*
509 ]))
510 },
511 )
512 }
513 }
514 StatementTree::Or(stvec) => {
515 let (proto, witness): (Vec<TokenStream>, Vec<TokenStream>) = stvec
516 .iter()
517 .map(|st| self.proto_witness_codegen(st))
518 .unzip();
519 (
520 quote! {
521 SigmaOk(ComposedRelation::or([
522 #(#proto?,)*
523 ]))
524 },
525 quote! {
526 SigmaOk(ComposedWitness::or([
527 #(#witness?,)*
528 ]))
529 },
530 )
531 }
532 StatementTree::Thresh(_thresh, _stvec) => {
533 todo! {"Thresh not yet implemented"};
534 }
535 }
536 }
537
538 pub fn generate(&mut self, emit_prover: bool, emit_verifier: bool) -> TokenStream {
544 let proto_name = &self.proto_name;
545 let group_name = &self.group_name;
546
547 let group_types = quote! {
548 use super::group;
549 pub type Scalar = <super::#group_name as group::Group>::Scalar;
550 pub type Point = super::#group_name;
551 };
552
553 self.statements.flatten_ands();
555
556 let mut pub_instance_fields = StructFieldList::default();
557 pub_instance_fields.push_vars(self.vars, true);
558
559 let instance_def = {
561 let decls = pub_instance_fields.field_decls();
562 #[cfg(feature = "dump")]
563 let dump_impl = {
564 let dump_chunks = pub_instance_fields.dump();
565 quote! {
566 impl Instance {
567 fn dump_scalar(s: &Scalar) {
568 let bytes: &[u8] = &s.to_repr();
569 print!("{:02x?}", bytes);
570 }
571
572 fn dump_point(p: &Point) {
573 let bytes: &[u8] = &p.to_bytes();
574 print!("{:02x?}", bytes);
575 }
576
577 pub fn dump(&self) {
578 #dump_chunks
579 }
580 }
581 }
582 };
583 #[cfg(not(feature = "dump"))]
584 let dump_impl = {
585 quote! {}
586 };
587 quote! {
588 #[derive(Clone)]
589 pub struct Instance {
590 #decls
591 }
592
593 #dump_impl
594 }
595 };
596
597 let mut witness_fields = StructFieldList::default();
598 witness_fields.push_vars(self.vars, false);
599
600 let witness_def = if emit_prover {
602 let decls = witness_fields.field_decls();
603 quote! {
604 #[derive(Clone)]
605 pub struct Witness {
606 #decls
607 }
608 }
609 } else {
610 quote! {}
611 };
612
613 let (protocol_code, witness_code) = self.proto_witness_codegen(self.statements);
614
615 let protocol_func = {
617 let instance_var = format_ident!("{}instance", self.unique_prefix);
618
619 quote! {
620 fn protocol(
621 #instance_var: &Instance,
622 ) -> SigmaResult<ComposedRelation<Point>> {
623 #protocol_code
624 }
625 }
626 };
627
628 let witness_func = if emit_prover {
630 quote! {
631 fn protocol_witness(
632 instance: &Instance,
633 witness: &Witness,
634 ) -> SigmaResult<ComposedWitness<Point>> {
635 #witness_code
636 }
637 }
638 } else {
639 quote! {}
640 };
641
642 let prove_func = if emit_prover {
644 let instance_var = format_ident!("{}instance", self.unique_prefix);
645 let witness_var = format_ident!("{}witness", self.unique_prefix);
646 let session_id_var = format_ident!("{}session_id", self.unique_prefix);
647 let rng_var = format_ident!("{}rng", self.unique_prefix);
648 let proto_var = format_ident!("{}proto", self.unique_prefix);
649 let proto_witness_var = format_ident!("{}proto_witness", self.unique_prefix);
650 let nizk_var = format_ident!("{}nizk", self.unique_prefix);
651
652 let dumper = if cfg!(feature = "dump") {
653 quote! {
654 println!("prover instance = {{");
655 #instance_var.dump();
656 println!("}}");
657 }
658 } else {
659 quote! {}
660 };
661
662 quote! {
663 pub fn prove(
664 #instance_var: &Instance,
665 #witness_var: &Witness,
666 #session_id_var: &[u8],
667 #rng_var: &mut (impl CryptoRng + RngCore),
668 ) -> SigmaResult<Vec<u8>> {
669 #dumper
670 let #proto_var = protocol(#instance_var)?;
671 let #proto_witness_var = protocol_witness(#instance_var, #witness_var)?;
672 let #nizk_var = #proto_var.into_nizk(#session_id_var);
673
674 #nizk_var.prove_batchable(&#proto_witness_var, #rng_var)
675 }
676 }
677 } else {
678 quote! {}
679 };
680
681 let verify_func = if emit_verifier {
683 let instance_var = format_ident!("{}instance", self.unique_prefix);
684 let proof_var = format_ident!("{}proof", self.unique_prefix);
685 let session_id_var = format_ident!("{}session_id", self.unique_prefix);
686 let proto_var = format_ident!("{}proto", self.unique_prefix);
687 let nizk_var = format_ident!("{}nizk", self.unique_prefix);
688
689 let dumper = if cfg!(feature = "dump") {
690 quote! {
691 println!("verifier instance = {{");
692 #instance_var.dump();
693 println!("}}");
694 }
695 } else {
696 quote! {}
697 };
698
699 quote! {
700 pub fn verify(
701 #instance_var: &Instance,
702 #proof_var: &[u8],
703 #session_id_var: &[u8],
704 ) -> SigmaResult<()> {
705 #dumper
706 let #proto_var = protocol(#instance_var)?;
707 let #nizk_var = #proto_var.into_nizk(#session_id_var);
708
709 #nizk_var.verify_batchable(#proof_var)
710 }
711 }
712 } else {
713 quote! {}
714 };
715
716 let dump_use = if cfg!(feature = "dump") {
718 quote! {
719 use group::GroupEncoding;
720 }
721 } else {
722 quote! {}
723 };
724 quote! {
725 #[allow(non_snake_case)]
726 pub mod #proto_name {
727 use super::sigma_compiler;
728 use sigma_compiler::sigma_proofs;
729 use sigma_compiler::group::ff::PrimeField;
730 use sigma_compiler::rand::{CryptoRng, RngCore};
731 use sigma_compiler::subtle::CtOption;
732 use sigma_compiler::vecutils::*;
733 use sigma_proofs::{
734 composition::{ComposedRelation, ComposedWitness},
735 errors::Error as SigmaError,
736 errors::Ok as SigmaOk,
737 errors::Result as SigmaResult,
738 LinearRelation, Nizk,
739 };
740 use std::ops::Neg;
741 #dump_use
742
743 #group_types
744 #instance_def
745 #witness_def
746
747 #protocol_func
748 #witness_func
749 #prove_func
750 #verify_func
751 }
752 }
753 }
754}