1use super::*;
4use crate::quantum_crypto::types::*;
5use std::collections::HashSet;
6
7impl ThresholdGroup {
8 pub fn check_permission(
10 &self,
11 participant_id: &ParticipantId,
12 permission: Permission,
13 ) -> Result<()> {
14 let participant = self.active_participants
15 .iter()
16 .find(|p| &p.participant_id == participant_id)
17 .ok_or_else(|| ThresholdError::ParticipantNotFound(participant_id.clone()))?;
18
19 match (&participant.role, permission) {
20 (ParticipantRole::Leader { permissions }, Permission::AddParticipant) => {
21 if permissions.can_add_participants {
22 Ok(())
23 } else {
24 Err(ThresholdError::Unauthorized("Cannot add participants".to_string()))
25 }
26 }
27 (ParticipantRole::Leader { permissions }, Permission::RemoveParticipant) => {
28 if permissions.can_remove_participants {
29 Ok(())
30 } else {
31 Err(ThresholdError::Unauthorized("Cannot remove participants".to_string()))
32 }
33 }
34 (ParticipantRole::Leader { permissions }, Permission::UpdateThreshold) => {
35 if permissions.can_update_threshold {
36 Ok(())
37 } else {
38 Err(ThresholdError::Unauthorized("Cannot update threshold".to_string()))
39 }
40 }
41 (ParticipantRole::Member { permissions }, Permission::Sign) => {
42 if permissions.can_sign {
43 Ok(())
44 } else {
45 Err(ThresholdError::Unauthorized("Cannot sign".to_string()))
46 }
47 }
48 (ParticipantRole::Observer, _) => {
49 Err(ThresholdError::Unauthorized("Observers have read-only access".to_string()))
50 }
51 _ => Err(ThresholdError::Unauthorized("Permission denied".to_string())),
52 }
53 }
54
55 pub fn get_active_participants(&self) -> Vec<&ParticipantInfo> {
57 self.active_participants
58 .iter()
59 .filter(|p| matches!(p.status, ParticipantStatus::Active))
60 .collect()
61 }
62
63 pub fn active_participant_count(&self) -> u16 {
65 self.get_active_participants().len() as u16
66 }
67
68 pub fn has_threshold_participants(&self) -> bool {
70 self.active_participant_count() >= self.threshold
71 }
72
73 pub fn add_pending_participant(&mut self, participant: ParticipantInfo) -> Result<()> {
75 if self.active_participants.iter().any(|p| p.participant_id == participant.participant_id) {
77 return Err(ThresholdError::InvalidParameters(
78 "Participant already exists".to_string()
79 ));
80 }
81
82 if self.pending_participants.iter().any(|p| p.participant_id == participant.participant_id) {
83 return Err(ThresholdError::InvalidParameters(
84 "Participant already pending".to_string()
85 ));
86 }
87
88 self.pending_participants.push(participant);
89 self.version += 1;
90 self.last_updated = SystemTime::now();
91
92 Ok(())
93 }
94
95 pub fn mark_for_removal(&mut self, participant_id: &ParticipantId) -> Result<()> {
97 let participant = self.active_participants
98 .iter_mut()
99 .find(|p| &p.participant_id == participant_id)
100 .ok_or_else(|| ThresholdError::ParticipantNotFound(participant_id.clone()))?;
101
102 participant.status = ParticipantStatus::PendingRemoval;
103 self.version += 1;
104 self.last_updated = SystemTime::now();
105
106 if self.active_participant_count() < self.threshold {
108 return Err(ThresholdError::InsufficientParticipants {
109 required: self.threshold,
110 available: self.active_participant_count(),
111 });
112 }
113
114 Ok(())
115 }
116
117 pub fn update_participant_role(
119 &mut self,
120 participant_id: &ParticipantId,
121 new_role: ParticipantRole,
122 ) -> Result<()> {
123 let participant = self.active_participants
124 .iter_mut()
125 .find(|p| &p.participant_id == participant_id)
126 .ok_or_else(|| ThresholdError::ParticipantNotFound(participant_id.clone()))?;
127
128 participant.role = new_role;
129 self.version += 1;
130 self.last_updated = SystemTime::now();
131
132 Ok(())
133 }
134
135 pub fn suspend_participant(
137 &mut self,
138 participant_id: &ParticipantId,
139 reason: String,
140 duration: std::time::Duration,
141 ) -> Result<()> {
142 let participant = self.active_participants
143 .iter_mut()
144 .find(|p| &p.participant_id == participant_id)
145 .ok_or_else(|| ThresholdError::ParticipantNotFound(participant_id.clone()))?;
146
147 participant.status = ParticipantStatus::Suspended {
148 reason,
149 until: SystemTime::now() + duration,
150 };
151
152 self.version += 1;
153 self.last_updated = SystemTime::now();
154
155 if self.active_participant_count() < self.threshold {
157 return Err(ThresholdError::InsufficientParticipants {
158 required: self.threshold,
159 available: self.active_participant_count(),
160 });
161 }
162
163 Ok(())
164 }
165
166 pub fn update_threshold(&mut self, new_threshold: u16) -> Result<()> {
168 if new_threshold == 0 {
169 return Err(ThresholdError::InvalidParameters(
170 "Threshold must be at least 1".to_string()
171 ));
172 }
173
174 if new_threshold > self.participants {
175 return Err(ThresholdError::InvalidParameters(
176 "Threshold cannot exceed total participants".to_string()
177 ));
178 }
179
180 if new_threshold > self.active_participant_count() {
181 return Err(ThresholdError::InvalidParameters(
182 "Threshold cannot exceed active participants".to_string()
183 ));
184 }
185
186 self.threshold = new_threshold;
187 self.version += 1;
188 self.last_updated = SystemTime::now();
189
190 Ok(())
191 }
192
193 pub fn get_participants_by_role(&self, role_filter: RoleFilter) -> Vec<&ParticipantInfo> {
195 self.active_participants
196 .iter()
197 .filter(|p| match (&p.role, &role_filter) {
198 (ParticipantRole::Leader { .. }, RoleFilter::Leaders) => true,
199 (ParticipantRole::Member { .. }, RoleFilter::Members) => true,
200 (ParticipantRole::Observer, RoleFilter::Observers) => true,
201 (_, RoleFilter::All) => true,
202 _ => false,
203 })
204 .collect()
205 }
206
207 pub fn get_hierarchy(&self) -> GroupHierarchy {
209 GroupHierarchy {
210 group_id: self.group_id.clone(),
211 parent: self.metadata.parent_group.clone(),
212 name: self.metadata.name.clone(),
213 threshold: self.threshold,
214 participants: self.participants,
215 purpose: self.metadata.purpose.clone(),
216 }
217 }
218
219 pub fn validate(&self) -> Result<()> {
221 if self.threshold == 0 {
223 return Err(ThresholdError::InvalidParameters(
224 "Invalid threshold: must be at least 1".to_string()
225 ));
226 }
227
228 if self.threshold > self.participants {
229 return Err(ThresholdError::InvalidParameters(
230 "Invalid threshold: exceeds total participants".to_string()
231 ));
232 }
233
234 let mut seen_ids = HashSet::new();
236 for participant in &self.active_participants {
237 if !seen_ids.insert(&participant.participant_id) {
238 return Err(ThresholdError::InvalidParameters(
239 format!("Duplicate participant ID: {:?}", participant.participant_id)
240 ));
241 }
242 }
243
244 let has_leader = self.active_participants
246 .iter()
247 .any(|p| matches!(p.role, ParticipantRole::Leader { .. }));
248
249 if !has_leader {
250 return Err(ThresholdError::InvalidParameters(
251 "Group must have at least one leader".to_string()
252 ));
253 }
254
255 Ok(())
256 }
257
258 pub fn add_audit_entry(&mut self, entry: GroupAuditEntry) {
260 self.audit_log.push(entry);
261
262 if self.audit_log.len() > 1000 {
264 self.audit_log.drain(0..100);
265 }
266 }
267}
268
/// An action a participant may request within a `ThresholdGroup`.
///
/// `ThresholdGroup::check_permission` currently grants `AddParticipant`, `RemoveParticipant`
/// and `UpdateThreshold` to leaders (subject to their flags) and `Sign` to members (subject to
/// their flags). `Vote`, `CreateSubgroup` and `AssignRoles` are not currently granted to any
/// role by that method.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Permission {
    /// Admit a new participant to the group.
    AddParticipant,
    /// Remove an existing participant from the group.
    RemoveParticipant,
    /// Change the group's signing threshold.
    UpdateThreshold,
    /// Take part in producing a signature.
    Sign,
    /// Cast a vote.
    Vote,
    /// Create a subgroup.
    CreateSubgroup,
    /// Assign roles to participants.
    AssignRoles,
}
280
/// Selects which participants `ThresholdGroup::get_participants_by_role` returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoleFilter {
    /// Every participant, regardless of role.
    All,
    /// Only participants with a leader role.
    Leaders,
    /// Only participants with a member role.
    Members,
    /// Only participants with the observer role.
    Observers,
}
289
/// Summary of a group's identity and configuration, as built by
/// `ThresholdGroup::get_hierarchy`.
#[derive(Debug, Clone)]
pub struct GroupHierarchy {
    /// Identifier of the group itself.
    pub group_id: GroupId,
    /// Parent group's id, copied from the group's `metadata.parent_group`; `None` when no
    /// parent is recorded.
    pub parent: Option<GroupId>,
    /// Human-readable group name, copied from `metadata.name`.
    pub name: String,
    /// Number of participants required to act (the group's `threshold`).
    pub threshold: u16,
    /// Configured total participant count (the group's `participants` field).
    pub participants: u16,
    /// The group's declared purpose, copied from `metadata.purpose`.
    pub purpose: GroupPurpose,
}
300
/// Aggregate counters describing a group's membership and audit history, as produced by
/// `ThresholdGroup::get_stats`.
///
/// `Default` yields all-zero counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GroupStats {
    /// Configured total participant count (the group's `participants` field).
    pub total_participants: u16,
    /// Participants whose status is `Active`.
    pub active_participants: u16,
    /// Entries waiting in the group's pending list.
    pub pending_participants: u16,
    /// Participants whose status is `Suspended`.
    pub suspended_participants: u16,
    /// Participants with a leader role (any status).
    pub leaders: u16,
    /// Participants with a member role (any status).
    pub members: u16,
    /// Participants with the observer role (any status).
    pub observers: u16,
    /// Number of entries in the audit log.
    pub total_operations: usize,
    /// Audit entries whose result is a success.
    pub successful_operations: usize,
    /// Audit entries whose result is a failure.
    pub failed_operations: usize,
}
315
316impl ThresholdGroup {
317 pub fn get_stats(&self) -> GroupStats {
319 let mut stats = GroupStats {
320 total_participants: self.participants,
321 active_participants: 0,
322 pending_participants: self.pending_participants.len() as u16,
323 suspended_participants: 0,
324 leaders: 0,
325 members: 0,
326 observers: 0,
327 total_operations: self.audit_log.len(),
328 successful_operations: 0,
329 failed_operations: 0,
330 };
331
332 for participant in &self.active_participants {
333 match &participant.status {
334 ParticipantStatus::Active => stats.active_participants += 1,
335 ParticipantStatus::Suspended { .. } => stats.suspended_participants += 1,
336 _ => {}
337 }
338
339 match &participant.role {
340 ParticipantRole::Leader { .. } => stats.leaders += 1,
341 ParticipantRole::Member { .. } => stats.members += 1,
342 ParticipantRole::Observer => stats.observers += 1,
343 }
344 }
345
346 for entry in &self.audit_log {
347 match &entry.result {
348 OperationResult::Success => stats.successful_operations += 1,
349 OperationResult::Failed(_) => stats.failed_operations += 1,
350 OperationResult::Pending => {}
351 }
352 }
353
354 stats
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles a `ParticipantInfo` with empty metadata, joined "now".
    fn make_participant(
        participant_id: ParticipantId,
        public_key: MlDsaPublicKey,
        frost_share_commitment: FrostCommitment,
        role: ParticipantRole,
        status: ParticipantStatus,
    ) -> ParticipantInfo {
        ParticipantInfo {
            participant_id,
            public_key,
            frost_share_commitment,
            role,
            status,
            joined_at: SystemTime::now(),
            metadata: HashMap::new(),
        }
    }

    /// A 2-of-2 group: participant 1 is an active leader, participant 2 an active member,
    /// both with default permissions.
    fn create_test_group() -> ThresholdGroup {
        let leader = make_participant(
            ParticipantId(1),
            MlDsaPublicKey(vec![1; 32]),
            FrostCommitment(vec![1; 32]),
            ParticipantRole::Leader {
                permissions: LeaderPermissions::default(),
            },
            ParticipantStatus::Active,
        );
        let member = make_participant(
            ParticipantId(2),
            MlDsaPublicKey(vec![2; 32]),
            FrostCommitment(vec![2; 32]),
            ParticipantRole::Member {
                permissions: MemberPermissions::default(),
            },
            ParticipantStatus::Active,
        );

        ThresholdGroup {
            group_id: GroupId([0; 32]),
            threshold: 2,
            participants: 2,
            frost_group_key: FrostGroupPublicKey(vec![0; 32]),
            active_participants: vec![leader, member],
            pending_participants: vec![],
            version: 1,
            metadata: GroupMetadata {
                name: "Test Group".to_string(),
                description: "Test group for unit tests".to_string(),
                purpose: GroupPurpose::MultiSig,
                parent_group: None,
                custom_data: HashMap::new(),
            },
            audit_log: vec![],
            created_at: SystemTime::now(),
            last_updated: SystemTime::now(),
        }
    }

    #[test]
    fn test_permission_checking() {
        let group = create_test_group();
        let leader_id = ParticipantId(1);
        let member_id = ParticipantId(2);

        assert!(group
            .check_permission(&leader_id, Permission::AddParticipant)
            .is_ok());
        assert!(group
            .check_permission(&member_id, Permission::AddParticipant)
            .is_err());
        assert!(group.check_permission(&member_id, Permission::Sign).is_ok());
    }

    #[test]
    fn test_group_validation() {
        let mut group = create_test_group();
        assert!(group.validate().is_ok(), "freshly built group should be valid");

        group.threshold = 0;
        assert!(group.validate().is_err(), "zero threshold must be rejected");

        // Exceeds the configured participant count of 2.
        group.threshold = 3;
        assert!(group.validate().is_err(), "threshold above participants must be rejected");
    }

    #[test]
    fn test_participant_management() {
        let mut group = create_test_group();

        let newcomer = make_participant(
            ParticipantId(3),
            MlDsaPublicKey(vec![3; 32]),
            FrostCommitment(vec![3; 32]),
            ParticipantRole::Member {
                permissions: MemberPermissions::default(),
            },
            ParticipantStatus::PendingJoin,
        );
        assert!(group.add_pending_participant(newcomer).is_ok());
        assert_eq!(group.pending_participants.len(), 1);

        // Id 1 already belongs to an active participant, so it must be refused.
        let clashing = make_participant(
            ParticipantId(1),
            MlDsaPublicKey(vec![1; 32]),
            FrostCommitment(vec![1; 32]),
            ParticipantRole::Member {
                permissions: MemberPermissions::default(),
            },
            ParticipantStatus::PendingJoin,
        );
        assert!(group.add_pending_participant(clashing).is_err());
    }
}