1use winterfell::{
28 crypto::{hashers::Rp64_256, DefaultRandomCoin, MerkleTree},
29 math::{fields::f64::BaseElement, FieldElement, ToElements},
30 matrix::ColMatrix,
31 AcceptableOptions, Air, AirContext, Assertion, BatchingMethod, DefaultConstraintCommitment,
32 DefaultConstraintEvaluator, DefaultTraceLde, EvaluationFrame, FieldExtension, Proof,
33 ProofOptions, Prover, StarkDomain, TraceInfo, TracePolyTable, TraceTable,
34 TransitionConstraintDegree,
35};
36
37use winterfell::crypto::hashers::Rp64_256 as Rescue;
39
/// Width of the Rescue permutation state, in field elements (the state arrays built below are
/// passed straight to `Rescue::apply_permutation`).
const STATE_W: usize = 12;
/// Number of Rescue rounds per permutation.
const NUM_ROUNDS: usize = 7;
/// Trace rows occupied by one permutation: the input row plus two half-round rows per round.
const ROWS_PER_HASH: usize = NUM_ROUNDS * 2 + 1;
/// Number of hashes laid out in the trace (four balance commitments and one amount commitment).
const NUM_HASHES: usize = 5;
/// Total number of rows used by all hash computations.
const HASH_ROWS: usize = NUM_HASHES * ROWS_PER_HASH;
/// Row, directly after the last hash, that holds the plaintext limbs checked for conservation.
const CONSERVATION_ROW: usize = HASH_ROWS;
/// Number of rows in the execution trace.
pub const TRACE_LEN: usize = 128;
/// Number of columns in the execution trace: the Rescue state plus one selector column.
pub const TRACE_W: usize = STATE_W + 1;
/// Index of the selector column that follows the Rescue state columns.
const SEL_COL: usize = STATE_W;

// Layout invariants: every hash row and the conservation row must fit inside the trace, and the
// trace length must be a power of two. Checking them at compile time turns a future change to
// the constants above into a build error instead of an out-of-bounds panic while filling the
// trace.
const _: () = assert!(
    CONSERVATION_ROW < TRACE_LEN,
    "hash rows and the conservation row must fit in the trace"
);
const _: () = assert!(
    TRACE_LEN.is_power_of_two(),
    "trace length must be a power of two"
);

/// Returns the first trace row (the input-state row) of hash number `h`.
fn hash_start(h: usize) -> usize {
    h * ROWS_PER_HASH
}
72
73#[derive(Clone, Debug)]
82pub struct CtPublicInputs {
83 pub s_old: [BaseElement; 4],
85 pub s_new: [BaseElement; 4],
87 pub r_old: [BaseElement; 4],
89 pub r_new: [BaseElement; 4],
91 pub amt: [BaseElement; 4],
93}
94
95impl ToElements<BaseElement> for CtPublicInputs {
96 fn to_elements(&self) -> Vec<BaseElement> {
97 let mut v = Vec::with_capacity(4 * NUM_HASHES);
98 for d in [
99 &self.s_old,
100 &self.s_new,
101 &self.r_old,
102 &self.r_new,
103 &self.amt,
104 ] {
105 v.extend_from_slice(d);
106 }
107 v
108 }
109}
110
111pub struct CtAir {
116 context: AirContext<BaseElement>,
117 pub_inputs: CtPublicInputs,
118}
119
120impl Air for CtAir {
121 type BaseField = BaseElement;
122 type PublicInputs = CtPublicInputs;
123
124 fn new(trace_info: TraceInfo, pub_inputs: CtPublicInputs, options: ProofOptions) -> Self {
125 let mut degrees: Vec<TransitionConstraintDegree> = (0..STATE_W)
134 .map(|_| TransitionConstraintDegree::with_cycles(7, vec![128, 128]))
135 .collect();
136 degrees.push(TransitionConstraintDegree::new(1)); for _ in 0..7 {
138 degrees.push(TransitionConstraintDegree::with_cycles(1, vec![128]));
139 }
140
141 Self {
142 context: AirContext::new(trace_info, degrees, 4 * NUM_HASHES, options),
143 pub_inputs,
144 }
145 }
146
147 fn context(&self) -> &AirContext<BaseElement> {
148 &self.context
149 }
150
151 fn evaluate_transition<E: FieldElement<BaseField = BaseElement>>(
152 &self,
153 frame: &EvaluationFrame<E>,
154 periodic_values: &[E],
155 result: &mut [E],
156 ) {
157 let cur = frame.current();
158 let next = frame.next();
159
160 let ark1 = &periodic_values[0..STATE_W];
166 let ark2 = &periodic_values[STATE_W..2 * STATE_W];
167 let is_first_half = periodic_values[2 * STATE_W];
168 let is_conservation = periodic_values[2 * STATE_W + 1];
169 let is_active = periodic_values[2 * STATE_W + 2];
170 let one = E::ONE;
171 let zero = E::ZERO;
172
173 let mds = Rescue::MDS;
184
185 let cur_pow7: Vec<E> = (0..STATE_W)
191 .map(|i| {
192 let c = cur[i];
193 let c2 = c * c;
194 let c4 = c2 * c2;
195 c4 * c2 * c
196 })
197 .collect();
198
199 let mut mds_cur_pow7 = vec![E::ZERO; STATE_W];
201 for i in 0..STATE_W {
202 for j in 0..STATE_W {
203 mds_cur_pow7[i] += E::from(mds[i][j]) * cur_pow7[j];
204 }
205 }
206
207 let _next_pow7: Vec<E> = (0..STATE_W)
209 .map(|i| {
210 let n = next[i];
211 let n2 = n * n;
212 let n4 = n2 * n2;
213 n4 * n2 * n
214 })
215 .collect();
216
217 let inv_mds = Rescue::INV_MDS;
222 let next_sub_ark2: Vec<E> = (0..STATE_W).map(|i| next[i] - ark2[i]).collect();
223 let mut inv_mds_next_sub_ark2 = vec![E::ZERO; STATE_W];
224 for i in 0..STATE_W {
225 for j in 0..STATE_W {
226 inv_mds_next_sub_ark2[i] += E::from(inv_mds[i][j]) * next_sub_ark2[j];
227 }
228 }
229 let inv_mds_pow7: Vec<E> = (0..STATE_W)
231 .map(|i| {
232 let v = inv_mds_next_sub_ark2[i];
233 let v2 = v * v;
234 let v4 = v2 * v2;
235 v4 * v2 * v
236 })
237 .collect();
238
239 for i in 0..STATE_W {
240 let fwd = next[i] - mds_cur_pow7[i] - ark1[i];
242 let bwd = cur[i] - inv_mds_pow7[i];
244 result[i] = is_active * (is_first_half * fwd + (one - is_first_half) * bwd);
246 }
247
248 result[SEL_COL] = zero;
250
251 result[STATE_W + 1] = is_conservation * (cur[0] - cur[8] - cur[2]);
254 result[STATE_W + 2] = is_conservation * (cur[4] + cur[8] - cur[6]);
255 result[STATE_W + 3] = is_conservation * cur[1];
256 result[STATE_W + 4] = is_conservation * cur[3];
257 result[STATE_W + 5] = is_conservation * cur[5];
258 result[STATE_W + 6] = is_conservation * cur[7];
259 result[STATE_W + 7] = is_conservation * cur[9];
260 }
261
262 fn get_assertions(&self) -> Vec<Assertion<BaseElement>> {
263 let mut assertions = Vec::new();
264
265 let digests = [
268 &self.pub_inputs.s_old,
269 &self.pub_inputs.s_new,
270 &self.pub_inputs.r_old,
271 &self.pub_inputs.r_new,
272 &self.pub_inputs.amt,
273 ];
274 for (h, digest) in digests.iter().enumerate() {
275 let output_row = hash_start(h) + ROWS_PER_HASH - 1;
276 for (d, &val) in digest.iter().enumerate() {
278 assertions.push(Assertion::single(4 + d, output_row, val));
279 }
280 }
281
282 assertions
283 }
284
285 fn get_periodic_column_values(&self) -> Vec<Vec<BaseElement>> {
286 let mut ark1_cols: Vec<Vec<BaseElement>> = (0..STATE_W)
292 .map(|_| vec![BaseElement::ZERO; TRACE_LEN])
293 .collect();
294 let mut ark2_cols: Vec<Vec<BaseElement>> = (0..STATE_W)
295 .map(|_| vec![BaseElement::ZERO; TRACE_LEN])
296 .collect();
297 let mut is_first_half_col = vec![BaseElement::ZERO; TRACE_LEN];
298 let mut is_conservation_col = vec![BaseElement::ZERO; TRACE_LEN];
299
300 for h in 0..NUM_HASHES {
301 for r in 0..NUM_ROUNDS {
302 let fwd_step = hash_start(h) + r * 2;
308 let bwd_step = fwd_step + 1;
309 for col in 0..STATE_W {
310 ark1_cols[col][fwd_step] = Rescue::ARK1[r][col];
312 ark2_cols[col][bwd_step] = Rescue::ARK2[r][col];
313 }
314 is_first_half_col[fwd_step] = BaseElement::ONE;
315 }
316 }
317 is_conservation_col[CONSERVATION_ROW] = BaseElement::ONE;
318
319 let mut is_active_col = vec![BaseElement::ZERO; TRACE_LEN];
322 for h in 0..NUM_HASHES {
323 for step in hash_start(h)..(hash_start(h) + ROWS_PER_HASH - 1) {
326 is_active_col[step] = BaseElement::ONE;
327 }
328 }
329 let mut cols: Vec<Vec<BaseElement>> = Vec::new();
333 cols.extend(ark1_cols);
334 cols.extend(ark2_cols);
335 cols.push(is_first_half_col);
336 cols.push(is_conservation_col);
337 cols.push(is_active_col);
338 cols
339 }
340}
341
/// Private witness for one confidential transfer.
///
/// Values are carried as `u128`; `CtProver::prove` range-checks the balances and the amount
/// before building the trace. Each balance is committed together with its nonce and a 16-byte
/// ciphertext hash (see `build_rescue_input`); the amount is committed with its nonce only.
pub struct CtWitness {
    /// Sender balance before the transfer.
    pub s_old_bal: u128,
    /// Sender balance after the transfer.
    pub s_new_bal: u128,
    /// Recipient balance before the transfer.
    pub r_old_bal: u128,
    /// Recipient balance after the transfer.
    pub r_new_bal: u128,
    /// Transferred amount.
    pub amount: u128,
    /// Nonce hashed into the sender's old-balance commitment.
    pub s_old_nonce: u128,
    /// Nonce hashed into the sender's new-balance commitment.
    pub s_new_nonce: u128,
    /// Nonce hashed into the recipient's old-balance commitment.
    pub r_old_nonce: u128,
    /// Nonce hashed into the recipient's new-balance commitment.
    pub r_new_nonce: u128,
    /// Nonce hashed into the amount commitment.
    pub amt_nonce: u128,
    /// Ciphertext hash bound into the sender's old-balance commitment.
    pub ct_hash_s_old: [u8; 16],
    /// Ciphertext hash bound into the sender's new-balance commitment.
    pub ct_hash_s_new: [u8; 16],
    /// Ciphertext hash bound into the recipient's old-balance commitment.
    pub ct_hash_r_old: [u8; 16],
    /// Ciphertext hash bound into the recipient's new-balance commitment.
    pub ct_hash_r_new: [u8; 16],
}
368
369fn u128_to_felts(v: u128) -> [BaseElement; 2] {
371 [
372 BaseElement::new(v as u64),
373 BaseElement::new((v >> 64) as u64),
374 ]
375}
376
377fn bytes16_to_felts(b: &[u8; 16]) -> [BaseElement; 2] {
379 let lo = u64::from_le_bytes(b[..8].try_into().unwrap());
380 let hi = u64::from_le_bytes(b[8..16].try_into().unwrap());
381 [BaseElement::new(lo), BaseElement::new(hi)]
382}
383
384fn build_rescue_input(bal: u128, nonce: u128, ct_hash: &[u8; 16]) -> [BaseElement; STATE_W] {
388 let mut state = [BaseElement::ZERO; STATE_W];
389 state[0] = BaseElement::new(6);
391 let bal_felts = u128_to_felts(bal);
393 let nonce_felts = u128_to_felts(nonce);
394 let ct_felts = bytes16_to_felts(ct_hash);
395 state[4] = bal_felts[0];
396 state[5] = bal_felts[1];
397 state[6] = nonce_felts[0];
398 state[7] = nonce_felts[1];
399 state[8] = ct_felts[0];
400 state[9] = ct_felts[1];
401 state
402}
403
404fn build_amount_rescue_input(amount: u128, nonce: u128) -> [BaseElement; STATE_W] {
407 let mut state = [BaseElement::ZERO; STATE_W];
408 state[0] = BaseElement::new(4); let amt_felts = u128_to_felts(amount);
410 let nonce_felts = u128_to_felts(nonce);
411 state[4] = amt_felts[0];
412 state[5] = amt_felts[1];
413 state[6] = nonce_felts[0];
414 state[7] = nonce_felts[1];
415 state
416}
417
418fn fill_rescue_trace(
421 trace: &mut TraceTable<BaseElement>,
422 start_row: usize,
423 initial_state: [BaseElement; STATE_W],
424 hash_idx: usize,
425) {
426 let mut state = initial_state;
427 for col in 0..STATE_W {
429 trace.set(col, start_row, state[col]);
430 }
431 trace.set(SEL_COL, start_row, BaseElement::new(hash_idx as u64));
432
433 let mut row = start_row;
434 for r in 0..NUM_ROUNDS {
435 for i in 0..STATE_W {
438 let s = state[i];
439 let s2 = s * s;
440 let s4 = s2 * s2;
441 state[i] = s4 * s2 * s;
442 }
443 let mut tmp = [BaseElement::ZERO; STATE_W];
445 for i in 0..STATE_W {
446 for j in 0..STATE_W {
447 tmp[i] += Rescue::MDS[i][j] * state[j];
448 }
449 }
450 state = tmp;
451 for i in 0..STATE_W {
453 state[i] += Rescue::ARK1[r][i];
454 }
455
456 row += 1;
458 for col in 0..STATE_W {
459 trace.set(col, row, state[col]);
460 }
461 trace.set(SEL_COL, row, BaseElement::new(hash_idx as u64));
462
463 const INV_ALPHA: u64 = 10540996611094048183;
467 for i in 0..STATE_W {
468 state[i] = state[i].exp(INV_ALPHA.into());
469 }
470 let mut tmp = [BaseElement::ZERO; STATE_W];
472 for i in 0..STATE_W {
473 for j in 0..STATE_W {
474 tmp[i] += Rescue::MDS[i][j] * state[j];
475 }
476 }
477 state = tmp;
478 for i in 0..STATE_W {
480 state[i] += Rescue::ARK2[r][i];
481 }
482
483 row += 1;
484 for col in 0..STATE_W {
485 trace.set(col, row, state[col]);
486 }
487 trace.set(SEL_COL, row, BaseElement::new(hash_idx as u64));
488 }
489}
490
491pub fn build_trace(w: &CtWitness) -> TraceTable<BaseElement> {
493 let mut trace = TraceTable::new(TRACE_W, TRACE_LEN);
494
495 fill_rescue_trace(
497 &mut trace,
498 hash_start(0),
499 build_rescue_input(w.s_old_bal, w.s_old_nonce, &w.ct_hash_s_old),
500 0,
501 );
502 fill_rescue_trace(
504 &mut trace,
505 hash_start(1),
506 build_rescue_input(w.s_new_bal, w.s_new_nonce, &w.ct_hash_s_new),
507 1,
508 );
509 fill_rescue_trace(
511 &mut trace,
512 hash_start(2),
513 build_rescue_input(w.r_old_bal, w.r_old_nonce, &w.ct_hash_r_old),
514 2,
515 );
516 fill_rescue_trace(
518 &mut trace,
519 hash_start(3),
520 build_rescue_input(w.r_new_bal, w.r_new_nonce, &w.ct_hash_r_new),
521 3,
522 );
523 fill_rescue_trace(
525 &mut trace,
526 hash_start(4),
527 build_amount_rescue_input(w.amount, w.amt_nonce),
528 4,
529 );
530
531 let row = CONSERVATION_ROW;
533 trace.set(0, row, BaseElement::new(w.s_old_bal as u64));
534 trace.set(1, row, BaseElement::new((w.s_old_bal >> 64) as u64));
535 trace.set(2, row, BaseElement::new(w.s_new_bal as u64));
536 trace.set(3, row, BaseElement::new((w.s_new_bal >> 64) as u64));
537 trace.set(4, row, BaseElement::new(w.r_old_bal as u64));
538 trace.set(5, row, BaseElement::new((w.r_old_bal >> 64) as u64));
539 trace.set(6, row, BaseElement::new(w.r_new_bal as u64));
540 trace.set(7, row, BaseElement::new((w.r_new_bal >> 64) as u64));
541 trace.set(8, row, BaseElement::new(w.amount as u64));
542 trace.set(9, row, BaseElement::new((w.amount >> 64) as u64));
543 trace.set(SEL_COL, row, BaseElement::new(99)); trace
546}
547
548pub struct CtProver {
553 options: ProofOptions,
554}
555
556impl CtProver {
557 pub fn new() -> Self {
558 Self {
559 options: ProofOptions::new(
560 40, 8, 20, FieldExtension::None,
564 8, 255, BatchingMethod::Algebraic,
567 BatchingMethod::Algebraic,
568 ),
569 }
570 }
571
572 pub fn prove(&self, witness: &CtWitness) -> Result<(Vec<u8>, CtPublicInputs), String> {
575 let s_old = witness.s_old_bal;
577 let s_new = witness.s_new_bal;
578 let r_old = witness.r_old_bal;
579 let r_new = witness.r_new_bal;
580 let amt = witness.amount;
581
582 if [s_old, s_new, r_old, r_new, amt]
583 .iter()
584 .any(|&value| value > u64::MAX as u128)
585 {
586 return Err("Confidential transfer balances and amount must fit in u64".into());
587 }
588 if amt == 0 {
589 return Err("Amount must be > 0".into());
590 }
591 let expected_s_new = s_old
592 .checked_sub(amt)
593 .ok_or("Sender balance underflow: old_balance < amount")?;
594 if s_new != expected_s_new {
595 return Err(format!("Conservation violated: {s_old} - {amt} != {s_new}"));
596 }
597 let expected_r_new = r_old.checked_add(amt).ok_or("Recipient balance overflow")?;
598 if r_new != expected_r_new {
599 return Err(format!("Conservation violated: {r_old} + {amt} != {r_new}"));
600 }
601
602 let s_old_digest = rescue_commit(s_old, witness.s_old_nonce, &witness.ct_hash_s_old);
604 let s_new_digest = rescue_commit(s_new, witness.s_new_nonce, &witness.ct_hash_s_new);
605 let r_old_digest = rescue_commit(r_old, witness.r_old_nonce, &witness.ct_hash_r_old);
606 let r_new_digest = rescue_commit(r_new, witness.r_new_nonce, &witness.ct_hash_r_new);
607 let amt_digest = rescue_commit_amount(amt, witness.amt_nonce);
608
609 let pub_inputs = CtPublicInputs {
610 s_old: s_old_digest,
611 s_new: s_new_digest,
612 r_old: r_old_digest,
613 r_new: r_new_digest,
614 amt: amt_digest,
615 };
616
617 let trace = build_trace(witness);
618 let proof =
619 Prover::prove(self, trace).map_err(|e| format!("Proof generation failed: {e}"))?;
620
621 let proof_bytes = proof.to_bytes();
622 Ok((proof_bytes, pub_inputs))
623 }
624}
625
626impl Prover for CtProver {
627 type BaseField = BaseElement;
628 type Air = CtAir;
629 type Trace = TraceTable<BaseElement>;
630 type HashFn = Rp64_256;
631 type VC = MerkleTree<Rp64_256>;
632 type RandomCoin = DefaultRandomCoin<Rp64_256>;
633 type TraceLde<E: FieldElement<BaseField = BaseElement>> =
634 DefaultTraceLde<E, Rp64_256, MerkleTree<Rp64_256>>;
635 type ConstraintEvaluator<'a, E: FieldElement<BaseField = BaseElement>> =
636 DefaultConstraintEvaluator<'a, CtAir, E>;
637 type ConstraintCommitment<E: FieldElement<BaseField = BaseElement>> =
638 DefaultConstraintCommitment<E, Rp64_256, MerkleTree<Rp64_256>>;
639
640 fn get_pub_inputs(&self, trace: &Self::Trace) -> CtPublicInputs {
641 let digest_of = |h: usize| -> [BaseElement; 4] {
643 let row = hash_start(h) + ROWS_PER_HASH - 1;
644 [
645 trace.get(4, row),
646 trace.get(5, row),
647 trace.get(6, row),
648 trace.get(7, row),
649 ]
650 };
651 CtPublicInputs {
652 s_old: digest_of(0),
653 s_new: digest_of(1),
654 r_old: digest_of(2),
655 r_new: digest_of(3),
656 amt: digest_of(4),
657 }
658 }
659
660 fn options(&self) -> &ProofOptions {
661 &self.options
662 }
663
664 fn new_trace_lde<E: FieldElement<BaseField = BaseElement>>(
665 &self,
666 trace_info: &TraceInfo,
667 main_trace: &ColMatrix<BaseElement>,
668 domain: &StarkDomain<BaseElement>,
669 partition_option: winterfell::PartitionOptions,
670 ) -> (Self::TraceLde<E>, TracePolyTable<E>) {
671 DefaultTraceLde::new(trace_info, main_trace, domain, partition_option)
672 }
673
674 fn new_evaluator<'a, E: FieldElement<BaseField = BaseElement>>(
675 &self,
676 air: &'a Self::Air,
677 aux_rand_elements: Option<winterfell::AuxRandElements<E>>,
678 composition_coefficients: winterfell::ConstraintCompositionCoefficients<E>,
679 ) -> Self::ConstraintEvaluator<'a, E> {
680 DefaultConstraintEvaluator::new(air, aux_rand_elements, composition_coefficients)
681 }
682
683 fn build_constraint_commitment<E: FieldElement<BaseField = BaseElement>>(
684 &self,
685 composition_poly_trace: winterfell::CompositionPolyTrace<E>,
686 num_constraint_composition_columns: usize,
687 domain: &StarkDomain<BaseElement>,
688 partition_options: winterfell::PartitionOptions,
689 ) -> (
690 Self::ConstraintCommitment<E>,
691 winterfell::CompositionPoly<E>,
692 ) {
693 DefaultConstraintCommitment::new(
694 composition_poly_trace,
695 num_constraint_composition_columns,
696 domain,
697 partition_options,
698 )
699 }
700}
701
702pub fn rescue_commit(bal: u128, nonce: u128, ct_hash: &[u8; 16]) -> [BaseElement; 4] {
709 let mut state = build_rescue_input(bal, nonce, ct_hash);
710 Rescue::apply_permutation(&mut state);
711 [state[4], state[5], state[6], state[7]]
712}
713
714pub fn rescue_commit_amount(amount: u128, nonce: u128) -> [BaseElement; 4] {
716 let mut state = build_amount_rescue_input(amount, nonce);
717 Rescue::apply_permutation(&mut state);
718 [state[4], state[5], state[6], state[7]]
719}
720
721pub fn digest_to_bytes(d: &[BaseElement; 4]) -> [u8; 32] {
723 let mut out = [0u8; 32];
724 for (i, e) in d.iter().enumerate() {
725 out[i * 8..(i + 1) * 8].copy_from_slice(&e.as_int().to_le_bytes());
726 }
727 out
728}
729
730pub fn bytes_to_digest(b: &[u8; 32]) -> [BaseElement; 4] {
732 let mut d = [BaseElement::ZERO; 4];
733 for i in 0..4 {
734 let v = u64::from_le_bytes(b[i * 8..(i + 1) * 8].try_into().unwrap());
735 d[i] = BaseElement::new(v);
736 }
737 d
738}
739
740pub fn verify_ct_proof(proof_bytes: &[u8], pub_inputs: &CtPublicInputs) -> Result<(), String> {
750 if proof_bytes.is_empty() {
751 return Err("Empty proof".into());
752 }
753 if proof_bytes.len() > 1024 * 1024 {
754 return Err("Proof too large".into());
755 }
756
757 let proof =
758 Proof::from_bytes(proof_bytes).map_err(|e| format!("Proof deserialization failed: {e}"))?;
759
760 let acceptable = AcceptableOptions::OptionSet(vec![ProofOptions::new(
761 40,
762 8,
763 20,
764 FieldExtension::None,
765 8,
766 255,
767 BatchingMethod::Algebraic,
768 BatchingMethod::Algebraic,
769 )]);
770
771 winterfell::verify::<CtAir, Rp64_256, DefaultRandomCoin<Rp64_256>, MerkleTree<Rp64_256>>(
772 proof,
773 pub_inputs.clone(),
774 &acceptable,
775 )
776 .map_err(|e| format!("STARK verification failed: {e}"))
777}
778
#[cfg(test)]
mod zk_tests {
    use super::*;

    /// Truncates the BLAKE3 hash of `ct` to the 16 bytes used as a ciphertext hash.
    fn ct_hash(ct: &[u8]) -> [u8; 16] {
        let digest = blake3::hash(ct);
        let mut truncated = [0u8; 16];
        truncated.copy_from_slice(&digest.as_bytes()[..16]);
        truncated
    }

    /// Witness with the given values, nonces 1..=5 and all-zero ciphertext hashes.
    fn plain_witness(
        s_old: u128,
        s_new: u128,
        r_old: u128,
        r_new: u128,
        amount: u128,
    ) -> CtWitness {
        CtWitness {
            s_old_bal: s_old,
            s_new_bal: s_new,
            r_old_bal: r_old,
            r_new_bal: r_new,
            amount,
            s_old_nonce: 1,
            s_new_nonce: 2,
            r_old_nonce: 3,
            r_new_nonce: 4,
            amt_nonce: 5,
            ct_hash_s_old: [0u8; 16],
            ct_hash_s_new: [0u8; 16],
            ct_hash_r_old: [0u8; 16],
            ct_hash_r_new: [0u8; 16],
        }
    }

    #[test]
    fn prove_and_verify_valid_transfer() {
        let amount = 500_000_000u128;
        let s_old = 1_000_000_000u128;
        let r_old = 200_000_000u128;

        let witness = CtWitness {
            s_old_bal: s_old,
            s_new_bal: s_old - amount,
            r_old_bal: r_old,
            r_new_bal: r_old + amount,
            amount,
            s_old_nonce: 0xDEADBEEF_01020304u128,
            s_new_nonce: 0xDEADBEEF_05060708u128,
            r_old_nonce: 0xCAFEBABE_01020304u128,
            r_new_nonce: 0xCAFEBABE_05060708u128,
            amt_nonce: 0x1234567890ABCDEFu128,
            ct_hash_s_old: ct_hash(&[0xABu8; 44]),
            ct_hash_s_new: ct_hash(&[0xCCu8; 44]),
            ct_hash_r_old: ct_hash(&[0xDDu8; 44]),
            ct_hash_r_new: ct_hash(&[0xEEu8; 44]),
        };

        let (proof_bytes, pub_inputs) = CtProver::new()
            .prove(&witness)
            .expect("Proof generation failed");

        assert!(!proof_bytes.is_empty(), "Proof must not be empty");
        assert_eq!(pub_inputs.to_elements().len(), 4 * NUM_HASHES);
        verify_ct_proof(&proof_bytes, &pub_inputs).expect("Proof verification failed");
    }

    #[test]
    fn prover_rejects_conservation_violation() {
        let witness = plain_witness(1_000, 600, 200, 500, 300);
        let err = CtProver::new().prove(&witness).unwrap_err();
        assert!(err.contains("Conservation violated"), "got: {err}");
    }

    #[test]
    fn prover_rejects_zero_amount() {
        let witness = plain_witness(1000, 1000, 0, 0, 0);
        let err = CtProver::new().prove(&witness).unwrap_err();
        assert!(err.contains("Amount must be > 0"), "got: {err}");
    }

    #[test]
    fn prover_rejects_underflow() {
        let witness = plain_witness(100, 0, 0, 200, 200);
        let err = CtProver::new().prove(&witness).unwrap_err();
        assert!(err.contains("underflow"), "got: {err}");
    }

    #[test]
    fn tampered_proof_fails_verification() {
        let witness = CtWitness {
            s_old_bal: 1000,
            s_new_bal: 700,
            r_old_bal: 200,
            r_new_bal: 500,
            amount: 300,
            s_old_nonce: 11,
            s_new_nonce: 22,
            r_old_nonce: 33,
            r_new_nonce: 44,
            amt_nonce: 55,
            ct_hash_s_old: [1u8; 16],
            ct_hash_s_new: [2u8; 16],
            ct_hash_r_old: [3u8; 16],
            ct_hash_r_new: [4u8; 16],
        };
        let (proof_bytes, mut pub_inputs) = CtProver::new().prove(&witness).unwrap();

        // Claim a different amount commitment than the one that was proven.
        pub_inputs.amt[0] += BaseElement::ONE;

        let err = verify_ct_proof(&proof_bytes, &pub_inputs).unwrap_err();
        assert!(
            err.contains("verification") || err.contains("failed"),
            "got: {err}"
        );
    }

    #[test]
    fn rescue_commitment_is_deterministic() {
        let first = rescue_commit(12345, 99999, &[0xABu8; 16]);
        let second = rescue_commit(12345, 99999, &[0xABu8; 16]);
        assert_eq!(first, second);
    }

    #[test]
    fn rescue_commitment_differs_on_different_inputs() {
        let base = rescue_commit(12345, 99999, &[0xABu8; 16]);
        let other_balance = rescue_commit(12346, 99999, &[0xABu8; 16]);
        let other_ct_hash = rescue_commit(12345, 99999, &[0xACu8; 16]);
        assert_ne!(base, other_balance);
        assert_ne!(base, other_ct_hash);
    }

    #[test]
    fn digest_roundtrip() {
        let digest = rescue_commit(999, 777, &[0x55u8; 16]);
        let decoded = bytes_to_digest(&digest_to_bytes(&digest));
        assert_eq!(digest, decoded);
    }
}