1use crate::error::SemaphoreError;
8use ark_ed_on_bn254::Fq;
9use ark_ff::{BigInteger, PrimeField};
10use lean_imt::hashed_tree::{HashedLeanIMT, LeanIMTHasher};
11use light_poseidon::{Poseidon, PoseidonHasher};
12
/// Number of bytes in one serialized tree element (a little-endian field element).
pub const ELEMENT_SIZE: usize = 32;

/// Fixed-size byte representation of a group member or tree node.
pub type Element = [u8; ELEMENT_SIZE];

/// The all-zero element: rejected as a new member and written over removed members.
pub const EMPTY_ELEMENT: Element = [0; ELEMENT_SIZE];
20
21pub type MerkleProof = lean_imt::lean_imt::MerkleProof<ELEMENT_SIZE>;
23
/// Zero-sized marker type that supplies Poseidon hashing to the lean incremental Merkle tree.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PoseidonHash;
27
28impl LeanIMTHasher<ELEMENT_SIZE> for PoseidonHash {
29 fn hash(input: &[u8]) -> [u8; ELEMENT_SIZE] {
30 let hash = Poseidon::<Fq>::new_circom(2)
31 .expect("Failed to initialize Poseidon")
32 .hash(&[
33 Fq::from_le_bytes_mod_order(&input[..ELEMENT_SIZE]),
34 Fq::from_le_bytes_mod_order(&input[ELEMENT_SIZE..]),
35 ])
36 .expect("Poseidon hash failed");
37
38 let mut hash_bytes = [0u8; ELEMENT_SIZE];
39 hash_bytes.copy_from_slice(&hash.into_bigint().to_bytes_le());
40
41 hash_bytes
42 }
43}
44
45#[derive(Debug, Default, Clone, PartialEq, Eq)]
46pub struct Group {
47 pub tree: HashedLeanIMT<ELEMENT_SIZE, PoseidonHash>,
49}
50
51impl Group {
52 pub fn new(members: &[Element]) -> Result<Self, SemaphoreError> {
54 if members.is_empty() {
55 return Ok(Group {
56 tree: HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::new(&[], PoseidonHash)?,
57 });
58 }
59
60 for &member in members {
61 if member == EMPTY_ELEMENT {
62 return Err(SemaphoreError::EmptyLeaf);
63 }
64 }
65
66 Ok(Group {
67 tree: HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::new(members, PoseidonHash)?,
68 })
69 }
70
71 pub fn root(&self) -> Option<Element> {
73 self.tree.root()
74 }
75
76 pub fn depth(&self) -> usize {
78 self.tree.depth()
79 }
80
81 pub fn size(&self) -> usize {
83 self.tree.size()
84 }
85
86 pub fn members(&self) -> Vec<Element> {
88 self.tree
89 .leaves()
90 .iter()
91 .map(|v| v.as_slice().try_into().unwrap())
92 .collect()
93 }
94
95 pub fn index_of(&self, member: Element) -> Option<usize> {
97 self.tree.index_of(&member)
98 }
99
100 pub fn add_member(&mut self, member: Element) -> Result<(), SemaphoreError> {
102 if member == EMPTY_ELEMENT {
103 return Err(SemaphoreError::EmptyLeaf);
104 }
105
106 self.tree.insert(&member);
107 Ok(())
108 }
109
110 pub fn add_members(&mut self, members: &[Element]) -> Result<(), SemaphoreError> {
112 for &member in members {
113 if member == EMPTY_ELEMENT {
114 return Err(SemaphoreError::EmptyLeaf);
115 }
116 }
117
118 self.tree.insert_many(members)?;
119 Ok(())
120 }
121
122 pub fn update_member(&mut self, index: usize, member: Element) -> Result<(), SemaphoreError> {
124 if self.members()[index] == EMPTY_ELEMENT {
125 return Err(SemaphoreError::RemovedMember);
126 }
127
128 self.tree.update(index, &member)?;
129 Ok(())
130 }
131
132 pub fn remove_member(&mut self, index: usize) -> Result<(), SemaphoreError> {
134 if self.members()[index] == EMPTY_ELEMENT {
135 return Err(SemaphoreError::AlreadyRemovedMember);
136 }
137
138 self.tree.update(index, &EMPTY_ELEMENT)?;
139 Ok(())
140 }
141
142 pub fn generate_proof(&self, index: usize) -> Result<MerkleProof, SemaphoreError> {
144 self.tree
145 .generate_proof(index)
146 .map_err(SemaphoreError::LeanIMTError)
147 }
148
149 pub fn verify_proof(proof: &MerkleProof) -> bool {
151 HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::verify_proof(proof)
152 }
153}
154
#[cfg(feature = "serde")]
impl Group {
    /// Serializes the group's underlying tree to a JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`SemaphoreError::SerializationError`] carrying the `serde_json` message if
    /// serialization fails.
    pub fn export(&self) -> Result<String, SemaphoreError> {
        let tree = self.tree.tree();
        serde_json::to_string(&tree).map_err(Self::serialization_error)
    }

    /// Rebuilds a group from JSON previously produced by [`Group::export`].
    ///
    /// NOTE(review): the deserialized tree is wrapped as-is via `new_from_tree`. Nothing here
    /// re-hashes or cross-checks its nodes, so presumably the JSON must come from a trusted
    /// source — confirm.
    ///
    /// # Errors
    ///
    /// Returns [`SemaphoreError::SerializationError`] if `json` cannot be deserialized.
    pub fn import(json: &str) -> Result<Self, SemaphoreError> {
        let lean_imt_tree =
            serde_json::from_str::<lean_imt::lean_imt::LeanIMT<ELEMENT_SIZE>>(json)
                .map_err(Self::serialization_error)?;

        let tree = HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::new_from_tree(
            lean_imt_tree,
            PoseidonHash,
        );
        Ok(Group { tree })
    }

    /// Wraps a `serde_json` failure in [`SemaphoreError::SerializationError`].
    fn serialization_error(err: serde_json::Error) -> SemaphoreError {
        SemaphoreError::SerializationError(err.to_string())
    }
}
174
175pub fn bytes_to_element(bytes: &[u8]) -> Result<Element, SemaphoreError> {
177 if bytes.len() > ELEMENT_SIZE {
178 return Err(SemaphoreError::InputSizeExceeded(bytes.len()));
179 }
180
181 let mut element = EMPTY_ELEMENT;
182 element[..bytes.len()].copy_from_slice(bytes);
183
184 Ok(element)
185}
186
187pub fn fq_to_element(fq: &Fq) -> Element {
189 let mut element = EMPTY_ELEMENT;
190 let bytes = fq.into_bigint().to_bytes_le();
191 element[..bytes.len()].copy_from_slice(&bytes);
192 element
193}
194
195pub fn element_to_fq(element: &Element) -> Fq {
197 Fq::from_le_bytes_mod_order(element)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conversions() {
        let test_bytes = [
            59, 227, 30, 252, 212, 244, 251, 255, 228, 174, 31, 212, 161, 61, 184, 169, 200, 50, 7,
            84, 65, 96,
        ];
        let element = bytes_to_element(&test_bytes).unwrap();
        let fq = element_to_fq(&element);
        let element_back = fq_to_element(&fq);

        assert_eq!(element, element_back);
        assert_eq!(fq, Fq::from_le_bytes_mod_order(&test_bytes));
        assert_eq!(
            bytes_to_element(&[0; 33]),
            Err(SemaphoreError::InputSizeExceeded(33))
        );
    }

    #[test]
    fn test_create_empty_group() {
        let group = Group::default();

        assert_eq!(group.root(), None);
        assert_eq!(group.depth(), 0);
        assert_eq!(group.size(), 0);
    }

    #[test]
    fn test_create_group_from_empty_slice() {
        let group = Group::new(&[]).unwrap();

        assert_eq!(group.size(), 0);
        assert!(group.members().is_empty());
    }

    #[test]
    fn test_create_group_with_members() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let member3 = [3; 32];

        let group1 = Group::new(&[member1, member2, member3]).unwrap();

        let mut group2 = Group::default();
        group2.add_member(member1).unwrap();
        group2.add_member(member2).unwrap();
        group2.add_member(member3).unwrap();

        assert_eq!(group1.root(), group2.root());
        assert_eq!(group1.depth(), 2);
        assert_eq!(group1.size(), 3);
    }

    #[test]
    fn test_create_group_with_zero_member() {
        let member1 = [1; 32];
        let zero = [0u8; ELEMENT_SIZE];

        let result = Group::new(&[member1, zero]);

        assert_eq!(result, Err(SemaphoreError::EmptyLeaf));
    }

    #[test]
    fn test_add_member() {
        let mut group = Group::default();
        let member = [1; 32];
        group.add_member(member).unwrap();

        assert_eq!(group.size(), 1);
    }

    #[test]
    fn test_add_zero_member() {
        let mut group = Group::default();
        let zero = [0u8; ELEMENT_SIZE];
        let result = group.add_member(zero);

        assert_eq!(result, Err(SemaphoreError::EmptyLeaf));
    }

    #[test]
    fn test_add_members() {
        let mut group = Group::default();
        let member1 = [1; 32];
        let member2 = [2; 32];

        group.add_members(&[member1, member2]).unwrap();

        assert_eq!(group.size(), 2);
    }

    #[test]
    fn test_add_members_with_zero() {
        let mut group = Group::default();
        let member1 = [1; 32];
        let zero = [0u8; ELEMENT_SIZE];

        let result = group.add_members(&[member1, zero]);

        assert_eq!(result, Err(SemaphoreError::EmptyLeaf));
    }

    #[test]
    fn test_index_of() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        let index = group.index_of(member2);

        assert_eq!(index, Some(1));
    }

    #[test]
    fn test_update_member() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let member3 = [3; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        let root_before = group.root();

        // Replace with a value different from the current leaf so the update is observable.
        group.update_member(0, member3).unwrap();

        assert_eq!(group.size(), 2);
        assert_eq!(group.members(), vec![member3, member2]);
        assert_ne!(group.root(), root_before);
    }

    #[test]
    fn test_update_removed_member() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        group.remove_member(0).unwrap();

        let result = group.update_member(0, member1);
        assert_eq!(result, Err(SemaphoreError::RemovedMember));
    }

    #[test]
    fn test_remove_member() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        group.remove_member(0).unwrap();

        let members = group.members();
        assert_eq!(members[0], [0u8; ELEMENT_SIZE]);
        assert_eq!(group.size(), 2);
    }

    #[test]
    fn test_remove_member_already_removed() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        group.remove_member(0).unwrap();

        let result = group.remove_member(0);

        assert_eq!(result, Err(SemaphoreError::AlreadyRemovedMember));
    }

    #[test]
    fn test_generate_merkle_proof() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();

        let proof = group.generate_proof(0).unwrap();
        assert_eq!(proof.leaf, member1);
    }

    #[test]
    fn test_verify_proof() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();

        let proof_0 = group.generate_proof(0).unwrap();
        assert!(Group::verify_proof(&proof_0));

        let mut proof_1 = group.generate_proof(1).unwrap();
        assert!(Group::verify_proof(&proof_1));

        // Swapping in a different leaf must invalidate the proof.
        proof_1.leaf = member1;
        assert!(!Group::verify_proof(&proof_1));
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_export_import() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let member3 = [3; 32];
        let group = Group::new(&[member1, member2, member3]).unwrap();

        let json = group.export().unwrap();
        let imported_group = Group::import(&json).unwrap();

        assert_eq!(group, imported_group);
    }
}