1#![no_std]
17#![doc = include_str!("../README.md")]
18#![deny(missing_docs)]
19
20use core::ops::{Add, Sub};
21use jub_jub::{Fp, JubjubAffine, JubjubExtended};
22use num_traits::{CheckedAdd, CheckedSub};
23use parity_scale_codec::{Decode, Encode};
24use serde::{Deserialize, Serialize};
25use zkstd::common::{TwistedEdwardsAffine, TwistedEdwardsCurve};
26
27#[derive(Debug, Clone, Copy, Encode, Decode, PartialEq, Eq, Deserialize, Serialize)]
30pub struct EncryptedNumber {
31 s: JubjubAffine,
32 t: JubjubAffine,
33}
34
35impl Default for EncryptedNumber {
36 fn default() -> Self {
37 Self {
38 s: JubjubAffine::ADDITIVE_IDENTITY,
39 t: JubjubAffine::ADDITIVE_IDENTITY,
40 }
41 }
42}
43
44impl EncryptedNumber {
47 pub fn new(s: JubjubAffine, t: JubjubAffine) -> Self {
49 Self { s, t }
50 }
51
52 pub fn encrypt(private_key: Fp, value: u32, random: Fp) -> Self {
54 let g = JubjubExtended::ADDITIVE_GENERATOR;
55 let public_key = g * private_key;
56 let left = g * Fp::from(value as u64) + public_key * random;
57 EncryptedNumber {
58 s: JubjubAffine::from(left),
59 t: JubjubAffine::from(g * random),
60 }
61 }
62
63 pub fn decrypt(&self, private_key: Fp) -> Option<u32> {
65 let g = JubjubExtended::ADDITIVE_GENERATOR;
66 let decrypted_message =
67 JubjubExtended::from(self.s) - (JubjubExtended::from(self.t) * private_key);
68
69 let mut acc = JubjubExtended::ADDITIVE_IDENTITY;
70 for i in 0..150000 {
71 if acc == decrypted_message {
72 return Some(i);
73 }
74 acc += g;
75 }
76 None
77 }
78
79 pub fn get_coordinate(self) -> (JubjubAffine, JubjubAffine) {
81 (self.s, self.t)
82 }
83}
84
85impl Add for EncryptedNumber {
86 type Output = Self;
87 #[inline]
88 fn add(self, rhs: Self) -> Self::Output {
89 Self {
90 s: JubjubAffine::from(JubjubExtended::from(self.s) + JubjubExtended::from(rhs.s)),
91 t: JubjubAffine::from(JubjubExtended::from(self.t) + JubjubExtended::from(rhs.t)),
92 }
93 }
94}
95
96impl Sub for EncryptedNumber {
97 type Output = Self;
98
99 #[inline]
100 fn sub(self, rhs: Self) -> Self::Output {
101 Self {
102 s: JubjubAffine::from(JubjubExtended::from(self.s) - JubjubExtended::from(rhs.s)),
103 t: JubjubAffine::from(JubjubExtended::from(self.t) - JubjubExtended::from(rhs.t)),
104 }
105 }
106}
107
108impl CheckedAdd for EncryptedNumber {
109 #[inline]
110 fn checked_add(&self, rhs: &Self) -> Option<Self> {
111 Some(Self {
112 s: JubjubAffine::from(JubjubExtended::from(self.s) + JubjubExtended::from(rhs.s)),
113 t: JubjubAffine::from(JubjubExtended::from(self.t) + JubjubExtended::from(rhs.t)),
114 })
115 }
116}
117
118impl CheckedSub for EncryptedNumber {
119 #[inline]
120 fn checked_sub(&self, rhs: &Self) -> Option<Self> {
121 Some(Self {
122 s: JubjubAffine::from(JubjubExtended::from(self.s) - JubjubExtended::from(rhs.s)),
123 t: JubjubAffine::from(JubjubExtended::from(self.t) - JubjubExtended::from(rhs.t)),
124 })
125 }
126}
127
128pub trait ConfidentialTransferPublicInputs<A: TwistedEdwardsAffine> {
130 fn init(s: A, t: A) -> Self;
132
133 fn get(self) -> (A, A);
135}
136
137impl ConfidentialTransferPublicInputs<JubjubAffine> for EncryptedNumber {
138 fn init(s: JubjubAffine, t: JubjubAffine) -> Self {
139 Self::new(s, t)
140 }
141
142 fn get(self) -> (JubjubAffine, JubjubAffine) {
143 self.get_coordinate()
144 }
145}
146
#[cfg(test)]
mod tests {
    use jub_jub::Fp;
    use rand::{thread_rng, Rng};
    use rand_core::OsRng;
    use zkstd::common::*;

    use crate::EncryptedNumber;

    /// Random scalar drawn from `OsRng`, used for keys and randomness.
    fn arb_fr() -> Fp {
        Fp::random(OsRng)
    }

    #[test]
    fn test_encrypt_decrypt() {
        let priv_k = arb_fr();
        let random = arb_fr();
        // `u16` keeps the plaintext well inside the range `decrypt` searches.
        let balance = thread_rng().gen::<u16>();
        let enc_balance = EncryptedNumber::encrypt(priv_k, balance as u32, random);

        let dec_balance = enc_balance.decrypt(priv_k);
        assert_eq!(dec_balance.unwrap(), balance as u32);
    }

    #[test]
    fn test_homomorphic() {
        let priv_k = arb_fr();
        let random1 = arb_fr();
        let random2 = arb_fr();
        let balance1 = thread_rng().gen::<u16>();
        let balance2 = thread_rng().gen::<u16>();
        // Order the balances so the subtraction does not underflow (a negative
        // plaintext would not be found by `decrypt`). The sum of two `u16`
        // values is at most 131_070, which `decrypt` can still recover.
        let (balance1, balance2) = if balance1 > balance2 {
            (balance1 as u32, balance2 as u32)
        } else {
            (balance2 as u32, balance1 as u32)
        };

        let enc_balance1 = EncryptedNumber::encrypt(priv_k, balance1, random1);
        let enc_balance2 = EncryptedNumber::encrypt(priv_k, balance2, random2);
        let enc_sub = enc_balance1 - enc_balance2;
        let enc_add = enc_balance1 + enc_balance2;

        let dec_sub = enc_sub.decrypt(priv_k);
        let dec_add = enc_add.decrypt(priv_k);

        assert_eq!(dec_sub.unwrap(), balance1 - balance2);
        assert_eq!(dec_add.unwrap(), balance1 + balance2);
    }

    #[test]
    fn test_elgamal() {
        // These are private keys: `encrypt` derives the public key from them.
        let alice_sk = arb_fr();
        let bob_sk = arb_fr();
        let alice_balance = thread_rng().gen::<u16>();
        let bob_balance = thread_rng().gen::<u16>();
        let transfer_amount = thread_rng().gen::<u16>();
        let alice_randomness = thread_rng().gen::<u64>();
        let bob_randomness = thread_rng().gen::<u64>();
        let alice_transfer_randomness = thread_rng().gen::<u64>();

        // Alice can only send what she has, so the transfer never underflows.
        let (alice_balance, transfer_amount) = if alice_balance > transfer_amount {
            (alice_balance as u32, transfer_amount as u32)
        } else {
            (transfer_amount as u32, alice_balance as u32)
        };
        let bob_balance = bob_balance as u32;

        // Order the randomness so `alice_randomness - alice_transfer_randomness`
        // below is a non-negative integer and never wraps around in `Fp`.
        let (alice_randomness, alice_transfer_randomness) =
            if alice_randomness > alice_transfer_randomness {
                (alice_randomness, alice_transfer_randomness)
            } else {
                (alice_transfer_randomness, alice_randomness)
            };
        let alice_randomness = Fp::from(alice_randomness);
        let bob_randomness = Fp::from(bob_randomness);
        let alice_transfer_randomness = Fp::from(alice_transfer_randomness);

        let alice_balance_enc = EncryptedNumber::encrypt(alice_sk, alice_balance, alice_randomness);
        let bob_balance_enc = EncryptedNumber::encrypt(bob_sk, bob_balance, bob_randomness);

        let transfer_amount_enc_alice =
            EncryptedNumber::encrypt(alice_sk, transfer_amount, alice_transfer_randomness);
        let transfer_amount_enc_bob =
            EncryptedNumber::encrypt(bob_sk, transfer_amount, alice_transfer_randomness);

        let alice_after_balance_enc = alice_balance_enc - transfer_amount_enc_alice;
        let bob_after_balance_enc = bob_balance_enc + transfer_amount_enc_bob;

        let alice_randomness_sum = alice_randomness - alice_transfer_randomness;
        let bob_randomness_sum = bob_randomness + alice_transfer_randomness;

        let explicit_alice = alice_balance - transfer_amount;
        let explicit_bob = bob_balance + transfer_amount;
        let exp_alice_balance_enc =
            EncryptedNumber::encrypt(alice_sk, explicit_alice, alice_randomness_sum);
        let exp_bob_balance_enc =
            EncryptedNumber::encrypt(bob_sk, explicit_bob, bob_randomness_sum);

        // Homomorphic updates must equal fresh encryptions of the new balances
        // under the combined randomness. Both components are checked, for Alice
        // as well as Bob.
        assert_eq!(exp_alice_balance_enc, alice_after_balance_enc);
        assert_eq!(exp_bob_balance_enc, bob_after_balance_enc);
    }
}