1use super::{Result, ThresholdError};
19use crate::quantum_crypto::types::*;
20use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24
25pub struct FrostSession {
27 pub session_id: [u8; 32],
29
30 pub message: Vec<u8>,
32
33 pub threshold: u16,
35
36 pub commitments: HashMap<ParticipantId, SigningCommitments>,
38
39 pub shares: HashMap<ParticipantId, SigningShare>,
41
42 pub group_public_key: FrostGroupPublicKey,
44
45 pub state: SessionState,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct SigningCommitments {
52 pub hiding: Vec<u8>,
53 pub binding: Vec<u8>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SigningShare {
59 pub share: Vec<u8>,
60}
61
/// Phase of a [`FrostSession`].
///
/// `Eq` is derived alongside `PartialEq` because every payload (`String`) has
/// total equality, so states can be used wherever `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Waiting for `threshold` participants to submit commitments.
    CollectingCommitments,
    /// Waiting for `threshold` committed participants to submit shares.
    CollectingShares,
    /// Enough shares are present; aggregation may run.
    ReadyToAggregate,
    /// Aggregation has produced a signature.
    Completed,
    /// The session failed, with a human-readable reason.
    Failed(String),
}
80
81pub struct KeyGenerationResult {
83 pub group_public_key: FrostGroupPublicKey,
85
86 pub shares: HashMap<ParticipantId, ParticipantShare>,
88
89 pub commitments: Vec<Vec<u8>>,
91}
92
93#[derive(Clone)]
95pub struct ParticipantShare {
96 pub participant_id: ParticipantId,
97 pub signing_share: Vec<u8>, pub verifying_share: Vec<u8>, }
100
101impl FrostSession {
102 pub fn new(message: Vec<u8>, threshold: u16, group_public_key: FrostGroupPublicKey) -> Self {
104 Self {
105 session_id: rand::random(),
106 message,
107 threshold,
108 commitments: HashMap::new(),
109 shares: HashMap::new(),
110 group_public_key,
111 state: SessionState::CollectingCommitments,
112 }
113 }
114
115 pub fn add_commitments(
117 &mut self,
118 participant_id: ParticipantId,
119 commitments: SigningCommitments,
120 ) -> Result<()> {
121 if self.state != SessionState::CollectingCommitments {
122 return Err(ThresholdError::InvalidShare(
123 "Not in commitment collection phase".to_string(),
124 ));
125 }
126
127 self.commitments.insert(participant_id, commitments);
128
129 if self.commitments.len() >= self.threshold as usize {
131 self.state = SessionState::CollectingShares;
132 }
133
134 Ok(())
135 }
136
137 pub fn add_share(&mut self, participant_id: ParticipantId, share: SigningShare) -> Result<()> {
139 if self.state != SessionState::CollectingShares {
140 return Err(ThresholdError::InvalidShare(
141 "Not in share collection phase".to_string(),
142 ));
143 }
144
145 if !self.commitments.contains_key(&participant_id) {
147 return Err(ThresholdError::InvalidShare(
148 "Participant did not provide commitments".to_string(),
149 ));
150 }
151
152 self.shares.insert(participant_id, share);
153
154 if self.shares.len() >= self.threshold as usize {
156 self.state = SessionState::ReadyToAggregate;
157 }
158
159 Ok(())
160 }
161
162 pub fn aggregate(&mut self) -> Result<FrostSignature> {
164 if self.state != SessionState::ReadyToAggregate {
165 return Err(ThresholdError::AggregationFailed(
166 "Not ready to aggregate".to_string(),
167 ));
168 }
169
170 let mut aggregated = Vec::new();
175 for (participant_id, share) in &self.shares {
176 aggregated.extend_from_slice(&participant_id.0.to_be_bytes());
177 aggregated.extend_from_slice(&share.share);
178 }
179
180 self.state = SessionState::Completed;
181
182 Ok(FrostSignature(aggregated))
183 }
184
185 pub fn is_complete(&self) -> bool {
187 matches!(self.state, SessionState::Completed)
188 }
189
190 pub fn get_progress(&self) -> SessionProgress {
192 SessionProgress {
193 session_id: self.session_id,
194 state: self.state.clone(),
195 commitments_received: self.commitments.len() as u16,
196 shares_received: self.shares.len() as u16,
197 threshold: self.threshold,
198 }
199 }
200}
201
202#[derive(Debug, Clone)]
204pub struct SessionProgress {
205 pub session_id: [u8; 32],
206 pub state: SessionState,
207 pub commitments_received: u16,
208 pub shares_received: u16,
209 pub threshold: u16,
210}
211
212pub struct FrostCoordinator {
214 pub sessions: HashMap<[u8; 32], FrostSession>,
216
217 pub groups: HashMap<GroupId, GroupInfo>,
219}
220
221pub struct GroupInfo {
223 pub group_public_key: FrostGroupPublicKey,
224 pub threshold: u16,
225 pub participants: Vec<ParticipantId>,
226}
227
228impl Default for FrostCoordinator {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234impl FrostCoordinator {
235 pub fn new() -> Self {
237 Self {
238 sessions: HashMap::new(),
239 groups: HashMap::new(),
240 }
241 }
242
243 pub fn register_group(
245 &mut self,
246 group_id: GroupId,
247 group_public_key: FrostGroupPublicKey,
248 threshold: u16,
249 participants: Vec<ParticipantId>,
250 ) {
251 self.groups.insert(
252 group_id,
253 GroupInfo {
254 group_public_key,
255 threshold,
256 participants,
257 },
258 );
259 }
260
261 pub fn initiate_signing(&mut self, group_id: &GroupId, message: Vec<u8>) -> Result<[u8; 32]> {
263 let group_info = self.groups.get(group_id).ok_or_else(|| {
264 ThresholdError::GroupOperationFailed("Group not registered".to_string())
265 })?;
266
267 let session = FrostSession::new(
268 message,
269 group_info.threshold,
270 group_info.group_public_key.clone(),
271 );
272
273 let session_id = session.session_id;
274 self.sessions.insert(session_id, session);
275
276 Ok(session_id)
277 }
278
279 pub fn process_commitment(
281 &mut self,
282 session_id: &[u8; 32],
283 participant_id: ParticipantId,
284 commitments: SigningCommitments,
285 ) -> Result<()> {
286 let session = self
287 .sessions
288 .get_mut(session_id)
289 .ok_or_else(|| ThresholdError::InvalidShare("Session not found".to_string()))?;
290
291 session.add_commitments(participant_id, commitments)
292 }
293
294 pub fn process_share(
296 &mut self,
297 session_id: &[u8; 32],
298 participant_id: ParticipantId,
299 share: SigningShare,
300 ) -> Result<()> {
301 let session = self
302 .sessions
303 .get_mut(session_id)
304 .ok_or_else(|| ThresholdError::InvalidShare("Session not found".to_string()))?;
305
306 session.add_share(participant_id, share)
307 }
308
309 pub fn complete_signing(&mut self, session_id: &[u8; 32]) -> Result<FrostSignature> {
311 let session = self
312 .sessions
313 .get_mut(session_id)
314 .ok_or_else(|| ThresholdError::AggregationFailed("Session not found".to_string()))?;
315
316 let signature = session.aggregate()?;
317
318 Ok(signature)
322 }
323
324 pub fn get_session_status(&self, session_id: &[u8; 32]) -> Option<SessionProgress> {
326 self.sessions.get(session_id).map(|s| s.get_progress())
327 }
328
329 pub fn cleanup_old_sessions(&mut self, _max_age: std::time::Duration) {
331 let _now = std::time::SystemTime::now();
332
333 self.sessions.retain(|_, session| {
334 !matches!(
336 session.state,
337 SessionState::Completed | SessionState::Failed(_)
338 )
339 });
340 }
341}
342
343pub async fn generate_key_shares(threshold: u16, participants: u16) -> Result<KeyGenerationResult> {
345 if threshold > participants {
346 return Err(ThresholdError::InvalidParameters(
347 "Threshold cannot exceed participants".to_string(),
348 ));
349 }
350
351 if threshold == 0 {
352 return Err(ThresholdError::InvalidParameters(
353 "Threshold must be at least 1".to_string(),
354 ));
355 }
356
357 let mut participant_shares = HashMap::new();
359
360 for i in 0..participants {
361 let participant_id = ParticipantId(i);
362 participant_shares.insert(
363 participant_id.clone(),
364 ParticipantShare {
365 participant_id: participant_id.clone(),
366 signing_share: vec![i as u8; 32], verifying_share: vec![i as u8; 32], },
369 );
370 }
371
372 let group_public_key = FrostGroupPublicKey(vec![0; 32]);
374
375 Ok(KeyGenerationResult {
376 group_public_key,
377 shares: participant_shares,
378 commitments: vec![], })
380}
381
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds 32-byte hiding/binding commitments filled with the given bytes.
    fn commitments_of(hiding: u8, binding: u8) -> SigningCommitments {
        SigningCommitments {
            hiding: vec![hiding; 32],
            binding: vec![binding; 32],
        }
    }

    /// Builds a 32-byte signature share filled with `byte`.
    fn share_of(byte: u8) -> SigningShare {
        SigningShare {
            share: vec![byte; 32],
        }
    }

    /// Drives one session through every phase: commitments, shares, aggregation.
    #[test]
    fn test_frost_session_lifecycle() {
        let mut session =
            FrostSession::new(b"Test message".to_vec(), 2, FrostGroupPublicKey(vec![0; 32]));
        assert_eq!(session.state, SessionState::CollectingCommitments);

        session
            .add_commitments(ParticipantId(1), commitments_of(1, 2))
            .unwrap();
        session
            .add_commitments(ParticipantId(2), commitments_of(3, 4))
            .unwrap();
        assert_eq!(session.state, SessionState::CollectingShares);

        session.add_share(ParticipantId(1), share_of(5)).unwrap();
        session.add_share(ParticipantId(2), share_of(6)).unwrap();
        assert_eq!(session.state, SessionState::ReadyToAggregate);

        let _signature = session.aggregate().unwrap();
        assert!(session.is_complete());
    }

    /// Key generation yields one share per participant and a non-empty group key.
    #[tokio::test]
    async fn test_key_generation() {
        let generated = generate_key_shares(2, 3).await.unwrap();

        assert_eq!(generated.shares.len(), 3);
        assert!(!generated.group_public_key.0.is_empty());
    }

    /// A freshly initiated session reports the group's threshold and no commitments.
    #[test]
    fn test_coordinator() {
        let group_id = GroupId([1; 32]);
        let members = vec![ParticipantId(1), ParticipantId(2), ParticipantId(3)];

        let mut coordinator = FrostCoordinator::new();
        coordinator.register_group(
            group_id.clone(),
            FrostGroupPublicKey(vec![0; 32]),
            2,
            members,
        );

        let session_id = coordinator
            .initiate_signing(&group_id, b"Test message".to_vec())
            .unwrap();
        let progress = coordinator.get_session_status(&session_id).unwrap();

        assert_eq!(progress.threshold, 2);
        assert_eq!(progress.commitments_received, 0);
    }
}