semaphore/
group.rs

//! Group module
//!
//! This module wraps the Poseidon-hashed `HashedLeanIMT` struct and adds utility methods for managing group members.
//!
//! Leaves and nodes have the same size: 32 bytes.
6
7use crate::error::SemaphoreError;
8use ark_ed_on_bn254::Fq;
9use ark_ff::{BigInteger, PrimeField};
10use lean_imt::hashed_tree::{HashedLeanIMT, LeanIMTHasher};
11use light_poseidon::{Poseidon, PoseidonHasher};
12
/// Size, in bytes, of every tree node and leaf.
pub const ELEMENT_SIZE: usize = 32;

/// A single tree node or leaf: a fixed array of `ELEMENT_SIZE` bytes.
pub type Element = [u8; ELEMENT_SIZE];

/// The all-zero element.
///
/// Groups refuse to add it as a member and write it into a leaf to mark that member as removed.
pub const EMPTY_ELEMENT: Element = [0; ELEMENT_SIZE];
20
21/// Merkle proof alias
22pub type MerkleProof = lean_imt::lean_imt::MerkleProof<ELEMENT_SIZE>;
23
24/// Poseidon LeanIMT hasher
25#[derive(Debug, Default, Clone, PartialEq, Eq)]
26pub struct PoseidonHash;
27
28impl LeanIMTHasher<ELEMENT_SIZE> for PoseidonHash {
29    fn hash(input: &[u8]) -> [u8; ELEMENT_SIZE] {
30        let hash = Poseidon::<Fq>::new_circom(2)
31            .expect("Failed to initialize Poseidon")
32            .hash(&[
33                Fq::from_le_bytes_mod_order(&input[..ELEMENT_SIZE]),
34                Fq::from_le_bytes_mod_order(&input[ELEMENT_SIZE..]),
35            ])
36            .expect("Poseidon hash failed");
37
38        let mut hash_bytes = [0u8; ELEMENT_SIZE];
39        hash_bytes.copy_from_slice(&hash.into_bigint().to_bytes_le());
40
41        hash_bytes
42    }
43}
44
45#[derive(Debug, Default, Clone, PartialEq, Eq)]
46pub struct Group {
47    /// Hashed LeanIMT
48    pub tree: HashedLeanIMT<ELEMENT_SIZE, PoseidonHash>,
49}
50
51impl Group {
52    /// Creates a new instance of the Group with optional initial members
53    pub fn new(members: &[Element]) -> Result<Self, SemaphoreError> {
54        if members.is_empty() {
55            return Ok(Group {
56                tree: HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::new(&[], PoseidonHash)?,
57            });
58        }
59
60        for &member in members {
61            if member == EMPTY_ELEMENT {
62                return Err(SemaphoreError::EmptyLeaf);
63            }
64        }
65
66        Ok(Group {
67            tree: HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::new(members, PoseidonHash)?,
68        })
69    }
70
71    /// Returns the root hash of the tree, or None if the tree is empty
72    pub fn root(&self) -> Option<Element> {
73        self.tree.root()
74    }
75
76    /// Returns the depth of the tree
77    pub fn depth(&self) -> usize {
78        self.tree.depth()
79    }
80
81    /// Returns the size of the tree (number of leaves)
82    pub fn size(&self) -> usize {
83        self.tree.size()
84    }
85
86    /// Returns the group members
87    pub fn members(&self) -> Vec<Element> {
88        self.tree
89            .leaves()
90            .iter()
91            .map(|v| v.as_slice().try_into().unwrap())
92            .collect()
93    }
94
95    /// Returns the index of a member if it exists
96    pub fn index_of(&self, member: Element) -> Option<usize> {
97        self.tree.index_of(&member)
98    }
99
100    /// Adds a new member to the group
101    pub fn add_member(&mut self, member: Element) -> Result<(), SemaphoreError> {
102        if member == EMPTY_ELEMENT {
103            return Err(SemaphoreError::EmptyLeaf);
104        }
105
106        self.tree.insert(&member);
107        Ok(())
108    }
109
110    /// Adds a set of members to the group
111    pub fn add_members(&mut self, members: &[Element]) -> Result<(), SemaphoreError> {
112        for &member in members {
113            if member == EMPTY_ELEMENT {
114                return Err(SemaphoreError::EmptyLeaf);
115            }
116        }
117
118        self.tree.insert_many(members)?;
119        Ok(())
120    }
121
122    /// Updates a group member
123    pub fn update_member(&mut self, index: usize, member: Element) -> Result<(), SemaphoreError> {
124        if self.members()[index] == EMPTY_ELEMENT {
125            return Err(SemaphoreError::RemovedMember);
126        }
127
128        self.tree.update(index, &member)?;
129        Ok(())
130    }
131
132    /// Removes a member from the group
133    pub fn remove_member(&mut self, index: usize) -> Result<(), SemaphoreError> {
134        if self.members()[index] == EMPTY_ELEMENT {
135            return Err(SemaphoreError::AlreadyRemovedMember);
136        }
137
138        self.tree.update(index, &EMPTY_ELEMENT)?;
139        Ok(())
140    }
141
142    /// Creates a proof of membership for a member
143    pub fn generate_proof(&self, index: usize) -> Result<MerkleProof, SemaphoreError> {
144        self.tree
145            .generate_proof(index)
146            .map_err(SemaphoreError::LeanIMTError)
147    }
148
149    /// Verifies a proof of membership for a member
150    pub fn verify_proof(proof: &MerkleProof) -> bool {
151        HashedLeanIMT::<ELEMENT_SIZE, PoseidonHash>::verify_proof(proof)
152    }
153}
154
#[cfg(feature = "serde")]
impl Group {
    /// Serializes the group's underlying LeanIMT tree into a JSON string.
    ///
    /// # Errors
    /// Returns [`SemaphoreError::SerializationError`] carrying the `serde_json` error
    /// message if serialization fails.
    pub fn export(&self) -> Result<String, SemaphoreError> {
        let tree = self.tree.tree();
        let to_error = |err: serde_json::Error| SemaphoreError::SerializationError(err.to_string());

        serde_json::to_string(&tree).map_err(to_error)
    }

    /// Rebuilds a group from a JSON string describing a LeanIMT tree, e.g. the output of [`Group::export`].
    ///
    /// The parsed tree is wrapped with [`PoseidonHash`]. Its leaves are not re-checked
    /// against [`EMPTY_ELEMENT`].
    ///
    /// # Errors
    /// Returns [`SemaphoreError::SerializationError`] carrying the `serde_json` error
    /// message if `json` cannot be parsed as a LeanIMT tree.
    pub fn import(json: &str) -> Result<Self, SemaphoreError> {
        let to_error = |err: serde_json::Error| SemaphoreError::SerializationError(err.to_string());

        let parsed_tree =
            serde_json::from_str::<lean_imt::lean_imt::LeanIMT<ELEMENT_SIZE>>(json).map_err(to_error)?;

        let tree = HashedLeanIMT::new_from_tree(parsed_tree, PoseidonHash);
        Ok(Self { tree })
    }
}
174
175/// Converts a byte array to an element
176pub fn bytes_to_element(bytes: &[u8]) -> Result<Element, SemaphoreError> {
177    if bytes.len() > ELEMENT_SIZE {
178        return Err(SemaphoreError::InputSizeExceeded(bytes.len()));
179    }
180
181    let mut element = EMPTY_ELEMENT;
182    element[..bytes.len()].copy_from_slice(bytes);
183
184    Ok(element)
185}
186
187/// Converts a scalar to an element
188pub fn fq_to_element(fq: &Fq) -> Element {
189    let mut element = EMPTY_ELEMENT;
190    let bytes = fq.into_bigint().to_bytes_le();
191    element[..bytes.len()].copy_from_slice(&bytes);
192    element
193}
194
195/// Converts an element to a scalar
196pub fn element_to_fq(element: &Element) -> Fq {
197    Fq::from_le_bytes_mod_order(element)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conversions() {
        let test_bytes = [
            59, 227, 30, 252, 212, 244, 251, 255, 228, 174, 31, 212, 161, 61, 184, 169, 200, 50, 7,
            84, 65, 96,
        ];
        let element = bytes_to_element(&test_bytes).unwrap();
        let fq = element_to_fq(&element);
        let element_back = fq_to_element(&fq);

        assert_eq!(element, element_back);
        assert_eq!(fq, Fq::from_le_bytes_mod_order(&test_bytes));

        // Exactly ELEMENT_SIZE bytes is accepted; one more is rejected.
        assert_eq!(
            bytes_to_element(&[7; ELEMENT_SIZE]),
            Ok([7; ELEMENT_SIZE])
        );
        assert_eq!(
            bytes_to_element(&[0; 33]),
            Err(SemaphoreError::InputSizeExceeded(33))
        );
    }

    #[test]
    fn test_create_empty_group() {
        let group = Group::default();

        assert_eq!(group.root(), None);
        assert_eq!(group.depth(), 0);
        assert_eq!(group.size(), 0);
    }

    #[test]
    fn test_create_group_with_members() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let member3 = [3; 32];

        let group1 = Group::new(&[member1, member2, member3]).unwrap();

        let mut group2 = Group::default();
        group2.add_member(member1).unwrap();
        group2.add_member(member2).unwrap();
        group2.add_member(member3).unwrap();

        assert_eq!(group1.root(), group2.root());
        assert_eq!(group1.depth(), 2);
        assert_eq!(group1.size(), 3);
    }

    #[test]
    fn test_create_group_with_zero_member() {
        let member1 = [1; 32];
        let zero = [0u8; ELEMENT_SIZE];

        let result = Group::new(&[member1, zero]);

        assert_eq!(result, Err(SemaphoreError::EmptyLeaf));
    }

    #[test]
    fn test_add_member() {
        let mut group = Group::default();
        let member = [1; 32];
        group.add_member(member).unwrap();

        assert_eq!(group.size(), 1);
        assert_eq!(group.members(), vec![member]);
    }

    #[test]
    fn test_add_zero_member() {
        let mut group = Group::default();
        let zero = [0u8; ELEMENT_SIZE];
        let result = group.add_member(zero);

        assert_eq!(result, Err(SemaphoreError::EmptyLeaf));
        assert_eq!(group.size(), 0);
    }

    #[test]
    fn test_add_members() {
        let mut group = Group::default();
        let member1 = [1; 32];
        let member2 = [2; 32];

        group.add_members(&[member1, member2]).unwrap();

        assert_eq!(group.size(), 2);
        assert_eq!(group.members(), vec![member1, member2]);
    }

    #[test]
    fn test_add_members_with_zero() {
        let mut group = Group::default();
        let member1 = [1; 32];
        let zero = [0u8; ELEMENT_SIZE];

        let result = group.add_members(&[member1, zero]);

        assert_eq!(result, Err(SemaphoreError::EmptyLeaf));
        // Validation happens before insertion, so nothing was added.
        assert_eq!(group.size(), 0);
    }

    #[test]
    fn test_index_of() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();

        assert_eq!(group.index_of(member1), Some(0));
        assert_eq!(group.index_of(member2), Some(1));
    }

    #[test]
    fn test_update_member() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let member3 = [3; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        let root_before = group.root();

        // Update to a *different* value so the test actually exercises an update.
        group.update_member(0, member3).unwrap();

        assert_eq!(group.size(), 2);
        assert_eq!(group.members(), vec![member3, member2]);
        assert_ne!(group.root(), root_before);
        assert_eq!(
            group.root(),
            Group::new(&[member3, member2]).unwrap().root()
        );
    }

    #[test]
    fn test_update_removed_member() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        group.remove_member(0).unwrap();

        let result = group.update_member(0, member1);
        assert_eq!(result, Err(SemaphoreError::RemovedMember));
    }

    #[test]
    fn test_remove_member() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        group.remove_member(0).unwrap();

        let members = group.members();
        assert_eq!(members[0], EMPTY_ELEMENT);
        assert_eq!(members[1], member2);
        assert_eq!(group.size(), 2);
    }

    #[test]
    fn test_remove_member_already_removed() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();
        group.remove_member(0).unwrap();

        let result = group.remove_member(0);

        assert_eq!(result, Err(SemaphoreError::AlreadyRemovedMember));
    }

    #[test]
    fn test_generate_merkle_proof() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();

        let proof = group.generate_proof(0).unwrap();
        assert_eq!(proof.leaf, member1);
    }

    #[test]
    fn test_verify_proof() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let mut group = Group::default();

        group.add_members(&[member1, member2]).unwrap();

        let proof_0 = group.generate_proof(0).unwrap();
        assert!(Group::verify_proof(&proof_0));

        let mut proof_1 = group.generate_proof(1).unwrap();
        assert!(Group::verify_proof(&proof_1));

        // Tampering with the leaf must invalidate the proof.
        proof_1.leaf = member1;
        assert!(!Group::verify_proof(&proof_1));
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_export_import() {
        let member1 = [1; 32];
        let member2 = [2; 32];
        let member3 = [3; 32];
        let group = Group::new(&[member1, member2, member3]).unwrap();

        let json = group.export().unwrap();
        let imported_group = Group::import(&json).unwrap();

        assert_eq!(group, imported_group);
    }
}