vsss_rs/
numbering.rs

1use core::fmt::Display;
2use core::{
3    fmt::{self, Debug, Formatter},
4    marker::PhantomData,
5    num::NonZeroUsize,
6};
7use rand_core::{CryptoRng, RngCore};
8use sha3::digest::ExtendableOutput;
9use sha3::{
10    digest::{Update, XofReader},
11    Shake256,
12};
13
14use crate::{Error, ShareIdentifier, VsssResult};
15
16/// The types of participant number generators
17#[derive(Debug, Clone)]
18pub enum ParticipantIdGeneratorType<'a, I: ShareIdentifier> {
19    /// Generate participant numbers sequentially beginning at `start` and incrementing by `increment`
20    /// until `count` is reached then this generator stops.
21    Sequential {
22        /// The starting identifier
23        start: I,
24        /// The amount to increment by each time a new id is needed
25        increment: I,
26        /// The total number of identifiers to generate
27        count: usize,
28    },
29    /// Generate participant numbers randomly using the provided `seed`
30    /// until `count` is reached then this generator stops.
31    Random {
32        /// The seed to use for the random number generator
33        seed: [u8; 32],
34        /// The total number of identifiers to generate
35        count: usize,
36    },
37    /// Use the provided list of identifiers
38    List {
39        /// The list of identifiers to use. Once all have been used the generator will stop
40        list: &'a [I],
41    },
42}
43
44impl<'a, I: ShareIdentifier + Copy> Copy for ParticipantIdGeneratorType<'a, I> {}
45
46impl<I: ShareIdentifier + Display> Display for ParticipantIdGeneratorType<'_, I> {
47    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
48        match self {
49            Self::Sequential {
50                start,
51                increment,
52                count,
53            } => write!(
54                f,
55                "Sequential {{ start: {}, increment: {}, count: {} }}",
56                start, increment, count
57            ),
58            Self::Random { seed, count } => {
59                write!(f, "Random {{ seed: ")?;
60                for &b in seed {
61                    write!(f, "{:02x}", b)?;
62                }
63                write!(f, ", count: {} }}", count)
64            }
65            Self::List { list } => {
66                write!(f, "List {{ list: ")?;
67                for id in list.iter() {
68                    write!(f, "{}, ", id)?;
69                }
70                write!(f, "}}")
71            }
72        }
73    }
74}
75
76impl<I: ShareIdentifier> Default for ParticipantIdGeneratorType<'_, I> {
77    fn default() -> Self {
78        Self::Sequential {
79            start: I::one(),
80            increment: I::one(),
81            count: u16::MAX as usize,
82        }
83    }
84}
85
86#[cfg(any(feature = "alloc", feature = "std"))]
87impl<'a, I: ShareIdentifier> From<&'a crate::Vec<I>> for ParticipantIdGeneratorType<'a, I> {
88    fn from(list: &'a crate::Vec<I>) -> Self {
89        Self::List { list }
90    }
91}
92
93impl<'a, I: ShareIdentifier> From<&'a [I]> for ParticipantIdGeneratorType<'a, I> {
94    fn from(list: &'a [I]) -> Self {
95        Self::List { list }
96    }
97}
98
99impl<'a, I: ShareIdentifier> ParticipantIdGeneratorType<'a, I> {
100    /// Create a new sequential participant number generator
101    pub fn sequential(start: Option<I>, increment: Option<I>, count: NonZeroUsize) -> Self {
102        Self::Sequential {
103            start: start.unwrap_or_else(I::one),
104            increment: increment.unwrap_or_else(I::one),
105            count: count.get(),
106        }
107    }
108
109    /// Create a new random participant number generator
110    pub fn random(seed: [u8; 32], count: NonZeroUsize) -> Self {
111        Self::Random {
112            seed,
113            count: count.get(),
114        }
115    }
116
117    /// Create a new list participant number generator
118    pub fn list(list: &'a [I]) -> Self {
119        Self::List { list }
120    }
121
122    pub(crate) fn try_into_generator(&self) -> VsssResult<ParticipantIdGeneratorState<'a, I>> {
123        match self {
124            Self::Sequential {
125                start,
126                increment,
127                count,
128            } => {
129                if *count == 0 {
130                    return Err(Error::InvalidGenerator(
131                        "The count must be greater than zero",
132                    ));
133                }
134                Ok(ParticipantIdGeneratorState::Sequential(
135                    SequentialParticipantNumberGenerator {
136                        start: start.clone(),
137                        increment: increment.clone(),
138                        index: 0,
139                        count: *count,
140                    },
141                ))
142            }
143            Self::Random { seed, count } => {
144                if *count == 0 {
145                    return Err(Error::InvalidGenerator(
146                        "The count must be greater than zero",
147                    ));
148                }
149                Ok(ParticipantIdGeneratorState::Random(
150                    RandomParticipantNumberGenerator {
151                        dst: *seed,
152                        index: 0,
153                        count: *count,
154                        _markers: PhantomData,
155                    },
156                ))
157            }
158            Self::List { list } => Ok(ParticipantIdGeneratorState::List(
159                ListParticipantNumberGenerator { list, index: 0 },
160            )),
161        }
162    }
163}
164
165/// A collection of participant number generators
166#[derive(Debug, Clone)]
167pub struct ParticipantIdGeneratorCollection<'a, 'b, I: ShareIdentifier> {
168    /// The collection of participant id generators
169    pub generators: &'a [ParticipantIdGeneratorType<'b, I>],
170}
171
172impl<'a, 'b, I: ShareIdentifier + Copy> Copy for ParticipantIdGeneratorCollection<'a, 'b, I> {}
173
174impl<'a, 'b, I: ShareIdentifier> From<&'a [ParticipantIdGeneratorType<'b, I>]>
175    for ParticipantIdGeneratorCollection<'a, 'b, I>
176{
177    fn from(generators: &'a [ParticipantIdGeneratorType<'b, I>]) -> Self {
178        Self { generators }
179    }
180}
181
182impl<'a, 'b, I: ShareIdentifier, const L: usize> From<&'a [ParticipantIdGeneratorType<'b, I>; L]>
183    for ParticipantIdGeneratorCollection<'a, 'b, I>
184{
185    fn from(generators: &'a [ParticipantIdGeneratorType<'b, I>; L]) -> Self {
186        Self { generators }
187    }
188}
189
190#[cfg(any(feature = "alloc", feature = "std"))]
191impl<'a, 'b, I: ShareIdentifier> From<&'a crate::Vec<ParticipantIdGeneratorType<'b, I>>>
192    for ParticipantIdGeneratorCollection<'a, 'b, I>
193{
194    fn from(generators: &'a crate::Vec<ParticipantIdGeneratorType<'b, I>>) -> Self {
195        Self {
196            generators: generators.as_slice(),
197        }
198    }
199}
200
201impl<'a, 'b, I: ShareIdentifier> ParticipantIdGeneratorCollection<'a, 'b, I> {
202    /// Returns an iterator that generates participant identifiers.
203    ///
204    /// The iterator will halt if an internal error occurs or an identifier
205    /// is generated that is the zero element.
206    pub fn iter(&self) -> impl Iterator<Item = I> + '_ {
207        let mut participant_id_iter = self.generators.iter().map(|g| g.try_into_generator());
208        let mut current: Option<ParticipantIdGeneratorState<'a, I>> = None;
209        core::iter::from_fn(move || {
210            loop {
211                if let Some(ref mut generator) = current {
212                    match generator.next() {
213                        Some(id) => {
214                            if id.is_zero().into() {
215                                current = None; // Move to next generator
216                                continue;
217                            }
218                            return Some(id);
219                        }
220                        None => {
221                            current = None; // Current generator exhausted, move to next
222                        }
223                    }
224                }
225
226                // If we're here, we need a new generator
227                match participant_id_iter.next() {
228                    Some(Ok(new_generator)) => {
229                        current = Some(new_generator);
230                        // Continue to next iteration to start using this generator
231                    }
232                    Some(Err(_)) => return None, // Errored generator
233                    None => return None,         // All generators exhausted
234                }
235            }
236        })
237    }
238}
239
240pub(crate) enum ParticipantIdGeneratorState<'a, I: ShareIdentifier> {
241    Sequential(SequentialParticipantNumberGenerator<I>),
242    Random(RandomParticipantNumberGenerator<I>),
243    List(ListParticipantNumberGenerator<'a, I>),
244}
245
246impl<'a, I: ShareIdentifier> Iterator for ParticipantIdGeneratorState<'a, I> {
247    type Item = I;
248
249    fn next(&mut self) -> Option<Self::Item> {
250        match self {
251            Self::Sequential(gen) => gen.next(),
252            Self::Random(gen) => gen.next(),
253            Self::List(gen) => gen.next(),
254        }
255    }
256}
257
258#[derive(Debug)]
259/// A generator that can create any number of secret shares
260pub(crate) struct SequentialParticipantNumberGenerator<I: ShareIdentifier> {
261    start: I,
262    increment: I,
263    index: usize,
264    count: usize,
265}
266
267impl<I: ShareIdentifier> Iterator for SequentialParticipantNumberGenerator<I> {
268    type Item = I;
269
270    fn next(&mut self) -> Option<Self::Item> {
271        if self.index >= self.count {
272            return None;
273        }
274        let value = self.start.clone();
275        self.start.inc(&self.increment);
276        self.index += 1;
277        Some(value)
278    }
279}
280
281/// A generator that creates random participant identifiers
282#[derive(Debug)]
283pub(crate) struct RandomParticipantNumberGenerator<I: ShareIdentifier> {
284    /// Domain separation tag
285    dst: [u8; 32],
286    index: usize,
287    count: usize,
288    _markers: PhantomData<I>,
289}
290
291impl<I: ShareIdentifier> Iterator for RandomParticipantNumberGenerator<I> {
292    type Item = I;
293
294    fn next(&mut self) -> Option<Self::Item> {
295        if self.index >= self.count {
296            return None;
297        }
298        self.index += 1;
299        Some(I::random(self.get_rng(self.index)))
300    }
301}
302
303impl<I: ShareIdentifier> RandomParticipantNumberGenerator<I> {
304    fn get_rng(&self, index: usize) -> XofRng {
305        let mut hasher = Shake256::default();
306        hasher.update(&self.dst);
307        hasher.update(&index.to_be_bytes());
308        hasher.update(&self.count.to_be_bytes());
309        XofRng(hasher.finalize_xof())
310    }
311}
312
313/// A generator that creates participant identifiers from a known list
314#[derive(Debug)]
315pub(crate) struct ListParticipantNumberGenerator<'a, I: ShareIdentifier> {
316    list: &'a [I],
317    index: usize,
318}
319
320impl<'a, I: ShareIdentifier> Iterator for ListParticipantNumberGenerator<'a, I> {
321    type Item = I;
322
323    fn next(&mut self) -> Option<Self::Item> {
324        if self.index >= self.list.len() {
325            return None;
326        }
327        let index = self.index;
328        self.index += 1;
329        Some(self.list[index].clone())
330    }
331}
332
333#[derive(Clone)]
334#[repr(transparent)]
335struct XofRng(<Shake256 as ExtendableOutput>::Reader);
336
337impl RngCore for XofRng {
338    fn next_u32(&mut self) -> u32 {
339        let mut buf = [0u8; 4];
340        self.0.read(&mut buf);
341        u32::from_be_bytes(buf)
342    }
343
344    fn next_u64(&mut self) -> u64 {
345        let mut buf = [0u8; 8];
346        self.0.read(&mut buf);
347        u64::from_be_bytes(buf)
348    }
349
350    fn fill_bytes(&mut self, dest: &mut [u8]) {
351        self.0.read(dest);
352    }
353
354    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
355        self.0.read(dest);
356        Ok(())
357    }
358}
359
360impl CryptoRng for XofRng {}
361
362impl Debug for XofRng {
363    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
364        write!(f, "XofRng")
365    }
366}
367
#[cfg(all(test, any(feature = "alloc", feature = "std")))]
mod tests {
    use super::*;
    use crate::*;
    use elliptic_curve::PrimeField;
    use k256::{FieldBytes, Scalar};
    use rand_core::SeedableRng;

    type Id = IdentifierPrimeField<Scalar>;

    /// Known-answer outputs of the random generator for the seed returned by
    /// [`seeded_dst`] with `count == 5`.
    const RANDOM_VECTORS: [&str; 5] = [
        "134de46908fd0867a9c14ed96e90cd34be47e2b052ca266499687adae4cfe445",
        "5b182d31afa277bcfb5d6316c31e231004d29f2c99e4dec0c384d7a46439c8ca",
        "cb15c36dfe7b15c253e3f9fde1fd9ccfbd75839ff6dccca49700cb831dc5802e",
        "bb3a92d716f6a8d94d82295fd120b23d42ec8543a405ecd82e519ab0fe4ef965",
        "a0fff4c9e992f0d1acc8bc90fe6ae31dee280a0175a028a6333dde56de2121ec",
    ];

    /// Builds an identifier from a small integer.
    fn id(value: u64) -> Id {
        IdentifierPrimeField::from(Scalar::from(value))
    }

    /// Builds an identifier from the hex encoding of a k256 scalar representation.
    fn id_from_hex(encoded: &str) -> Id {
        let mut repr = FieldBytes::default();
        repr.copy_from_slice(&hex::decode(encoded).unwrap());
        IdentifierPrimeField::from(Scalar::from_repr(repr).unwrap())
    }

    /// Reproduces the fixed 32-byte seed the random-generator tests rely on.
    fn seeded_dst() -> [u8; 32] {
        let mut rng = rand_chacha::ChaCha8Rng::from_seed([1u8; 32]);
        let mut dst = [0u8; 32];
        rng.fill_bytes(&mut dst);
        dst
    }

    /// The five-entry list shared by the list-based tests.
    fn small_list() -> [Id; 5] {
        [id(10), id(20), id(30), id(40), id(50)]
    }

    #[test]
    fn test_sequential_participant_number_generator() {
        let generator = SequentialParticipantNumberGenerator::<Id> {
            start: Id::ONE,
            increment: Id::ONE,
            index: 0,
            count: 5,
        };
        let produced: Vec<_> = generator.collect();
        let expected: Vec<_> = (1..=5).map(id).collect();
        assert_eq!(produced, expected);
    }

    #[test]
    fn test_random_participant_number_generator() {
        let generator = RandomParticipantNumberGenerator::<Id> {
            dst: seeded_dst(),
            index: 0,
            count: 5,
            _markers: PhantomData,
        };
        let produced: Vec<_> = generator.collect();
        let expected: Vec<_> = RANDOM_VECTORS.iter().copied().map(id_from_hex).collect();
        assert_eq!(produced, expected);
    }

    #[test]
    fn test_list_participant_number_generator() {
        let source = small_list();
        let generator = ListParticipantNumberGenerator {
            list: &source,
            index: 0,
        };
        let produced: Vec<_> = generator.collect();
        assert_eq!(produced, source);
    }

    #[test]
    fn test_list_and_sequential_number_generator() {
        let list = small_list();
        let set = [
            ParticipantIdGeneratorType::list(&list),
            ParticipantIdGeneratorType::sequential(
                Some(id(51)),
                Some(Id::ONE),
                NonZeroUsize::new(5).unwrap(),
            ),
        ];
        let collection = ParticipantIdGeneratorCollection::from(&set[..]);

        let produced: Vec<_> = collection.iter().collect();
        let expected: Vec<_> = [10, 20, 30, 40, 50]
            .iter()
            .copied()
            .chain(51..=55)
            .map(id)
            .collect();
        assert_eq!(produced, expected);
    }

    #[test]
    fn test_list_and_random_number_generator() {
        let list = small_list();
        let set = [
            ParticipantIdGeneratorType::list(&list),
            ParticipantIdGeneratorType::random(seeded_dst(), NonZeroUsize::new(5).unwrap()),
        ];
        let collection = ParticipantIdGeneratorCollection::from(&set);

        let produced: Vec<_> = collection.iter().collect();
        let expected: Vec<Id> = list
            .iter()
            .cloned()
            .chain(RANDOM_VECTORS.iter().copied().map(id_from_hex))
            .collect();
        assert_eq!(produced, expected);
    }

    #[test]
    fn test_empty_list_and_sequential_number_generator() {
        let list: [Id; 0] = [];
        let generators = [
            ParticipantIdGeneratorType::list(&list),
            ParticipantIdGeneratorType::sequential(None, None, NonZeroUsize::new(5).unwrap()),
        ];
        let collection = ParticipantIdGeneratorCollection::from(&generators);

        let produced: Vec<_> = collection.iter().collect();
        let expected: Vec<_> = (1..=5).map(id).collect();
        assert_eq!(produced, expected);
    }
}