use crate::{
mask::{masking::Aggregation, object::MaskObject},
state_machine::{
coordinator::{CoordinatorState, MaskDict},
phases::{
reject_request,
Handler,
Phase,
PhaseName,
PhaseState,
Purge,
StateError,
Unmask,
},
requests::{Request, RequestReceiver, Sum2Request, Sum2Response},
StateMachine,
},
PetError,
SumDict,
SumParticipantPublicKey,
};
use tokio::{
sync::oneshot,
time::{timeout, Duration},
};
/// State that is specific to the sum2 phase.
#[derive(Debug)]
pub struct Sum2 {
/// Dictionary of sum participants. An entry is removed from it each time a
/// mask from the corresponding participant is accepted.
sum_dict: SumDict,
/// Aggregation handed over unchanged to the unmask phase once this phase
/// completes.
aggregation: Aggregation,
/// Counts how many times each distinct mask has been submitted.
mask_dict: MaskDict,
}
// Read-only accessors that are compiled for tests only.
#[cfg(test)]
impl Sum2 {
/// Returns a reference to the sum dictionary held by this phase.
pub fn sum_dict(&self) -> &SumDict {
&self.sum_dict
}
/// Returns a reference to the aggregation held by this phase.
pub fn aggregation(&self) -> &Aggregation {
&self.aggregation
}
/// Returns a reference to the dictionary of submitted masks and their counts.
pub fn mask_dict(&self) -> &MaskDict {
&self.mask_dict
}
}
#[async_trait]
impl<R> Phase<R> for PhaseState<R, Sum2>
where
    Self: Purge<R> + Handler<R>,
    R: Send,
{
    const NAME: PhaseName = PhaseName::Sum2;

    /// Runs the sum2 phase.
    ///
    /// The phase event is broadcast first. Then `process_during` is awaited
    /// for `min_sum_time` seconds, after which `process_until_enough` is
    /// awaited under a timeout covering the remainder of `max_sum_time`.
    ///
    /// # Errors
    ///
    /// Returns a `StateError` if either processing step fails or if the
    /// timeout elapses before `process_until_enough` completes.
    async fn run(&mut self) -> Result<(), StateError> {
        info!("starting sum2 phase");

        info!("broadcasting sum2 phase event");
        self.coordinator_state.events.broadcast_phase(
            self.coordinator_state.round_params.seed.clone(),
            PhaseName::Sum2,
        );

        let min_time = self.coordinator_state.min_sum_time;
        debug!("in sum2 phase for a minimum of {} seconds", min_time);
        self.process_during(Duration::from_secs(min_time)).await?;

        // Saturate instead of subtracting directly: if the configured maximum
        // is smaller than the minimum, a plain `-` would panic in debug builds
        // and wrap around to a practically infinite timeout in release builds.
        // With saturation the remaining budget is simply zero.
        let time_left = self
            .coordinator_state
            .max_sum_time
            .saturating_sub(min_time);
        // The first `?` handles the timeout elapsing, the second one handles
        // a failure of the processing itself.
        timeout(Duration::from_secs(time_left), self.process_until_enough()).await??;

        info!(
            "{} sum2 messages handled (min {} required)",
            self.mask_count(),
            self.coordinator_state.min_sum_count
        );
        Ok(())
    }

    /// Moves on to the unmask phase, handing over the coordinator state, the
    /// request receiver, the aggregation and the collected mask dictionary.
    fn next(self) -> Option<StateMachine<R>> {
        Some(
            PhaseState::<R, Unmask>::new(
                self.coordinator_state,
                self.request_rx,
                self.inner.aggregation,
                self.inner.mask_dict,
            )
            .into(),
        )
    }
}
impl<R> PhaseState<R, Sum2>
where
    Self: Handler<R> + Phase<R> + Purge<R>,
{
    /// Keeps awaiting `process_single` until `has_enough_sum2s` reports that
    /// the required number of sum2 messages has been reached.
    ///
    /// Returns immediately if the requirement is already met. Any error from
    /// `process_single` is propagated to the caller.
    async fn process_until_enough(&mut self) -> Result<(), StateError> {
        loop {
            if self.has_enough_sum2s() {
                return Ok(());
            }

            debug!(
                "{} sum2 messages handled (min {} required)",
                self.mask_count(),
                self.coordinator_state.min_sum_count
            );
            self.process_single().await?;
        }
    }
}
impl<R> Handler<Request> for PhaseState<R, Sum2> {
    /// Dispatches an incoming request.
    ///
    /// Sum2 requests are forwarded to `handle_sum2` together with their
    /// response channel; every other request kind is passed to
    /// `reject_request`.
    fn handle_request(&mut self, req: Request) {
        match req {
            Request::Sum2((payload, reply_tx)) => self.handle_sum2(payload, reply_tx),
            other => reject_request(other),
        }
    }
}
impl<R> PhaseState<R, Sum2> {
    /// Creates the sum2 phase state from the given coordinator state, request
    /// receiver, sum dictionary and aggregation.
    ///
    /// The mask dictionary starts out empty.
    pub fn new(
        coordinator_state: CoordinatorState,
        request_rx: RequestReceiver<R>,
        sum_dict: SumDict,
        aggregation: Aggregation,
    ) -> Self {
        info!("state transition");
        Self {
            inner: Sum2 {
                sum_dict,
                aggregation,
                mask_dict: MaskDict::new(),
            },
            coordinator_state,
            request_rx,
        }
    }

    /// Handles a sum2 request by recording its mask and sending the outcome
    /// back through `response_tx`.
    fn handle_sum2(&mut self, req: Sum2Request, response_tx: oneshot::Sender<Sum2Response>) {
        let Sum2Request {
            participant_pk,
            mask,
        } = req;

        // The send result is deliberately discarded: there is nothing useful
        // to do here if the response cannot be delivered.
        let _ = response_tx.send(self.add_mask(&participant_pk, mask));
    }

    /// Records `mask` as submitted by the sum participant identified by `pk`.
    ///
    /// The participant is removed from the sum dictionary in the process, so a
    /// second submission with the same key is rejected.
    ///
    /// # Errors
    ///
    /// Returns `PetError::InvalidMessage` if `pk` is not present in the sum
    /// dictionary.
    fn add_mask(&mut self, pk: &SumParticipantPublicKey, mask: MaskObject) -> Result<(), PetError> {
        if self.inner.sum_dict.remove(pk).is_none() {
            return Err(PetError::InvalidMessage);
        }

        // Use the entry API so that the mask is hashed and looked up only
        // once, whether or not it has been seen before.
        *self.inner.mask_dict.entry(mask).or_insert(0) += 1;
        Ok(())
    }

    /// Returns the total number of masks recorded so far, i.e. the sum of all
    /// per-mask counters.
    fn mask_count(&self) -> usize {
        self.inner.mask_dict.values().sum()
    }

    /// Returns `true` once the number of recorded masks has reached
    /// `min_sum_count`.
    fn has_enough_sum2s(&self) -> bool {
        self.mask_count() >= self.coordinator_state.min_sum_count
    }
}
#[cfg(test)]
mod test {
use super::*;
use crate::{
crypto::{ByteObject, EncryptKeyPair},
mask::{FromPrimitives, Model},
state_machine::{
coordinator::RoundSeed,
events::Event,
tests::{
builder::StateMachineBuilder,
utils::{generate_summer, generate_updater, mask_settings},
},
},
SumDict,
};
// Starts a state machine directly in the sum2 phase with one sum participant
// and one already aggregated masked model, sends a single sum2 message and
// checks that the machine moves on to the unmask phase with that mask
// recorded exactly once and with an aggregation that unmasks back to the
// original model.
#[tokio::test]
pub async fn sum2_to_unmask() {
let n_updaters = 1;
let n_summers = 1;
let seed = RoundSeed::generate();
let sum_ratio = 0.5;
let update_ratio = 1.0;
let coord_keys = EncryptKeyPair::generate();
let model_size = 4;
// Create the sum participant and register it in the sum dictionary under
// the key returned by `ephm_pk()` of its sum message.
let mut summer = generate_summer(&seed, sum_ratio, update_ratio);
let ephm_pk = summer.compose_sum_message(&coord_keys.public).ephm_pk();
let mut sum_dict = SumDict::new();
sum_dict.insert(summer.pk, ephm_pk);
// Create the update participant and let it compose an update message for
// an all-zero model of `model_size` elements.
let updater = generate_updater(&seed, sum_ratio, update_ratio);
let scalar = 1.0 / (n_updaters as f64 * update_ratio);
let model = Model::from_primitives(vec![0; model_size].into_iter()).unwrap();
let msg =
updater.compose_update_message(coord_keys.public, &sum_dict, scalar, model.clone());
let masked_model = msg.masked_model();
let local_seed_dict = msg.local_seed_dict();
// Pre-populate the aggregation with the updater's masked model, as the
// state under test begins in sum2.
let mut aggregation = Aggregation::new(mask_settings().into(), model_size);
aggregation.aggregate(masked_model.clone());
let sum2 = Sum2 {
sum_dict,
aggregation,
mask_dict: MaskDict::new(),
};
let (state_machine, request_tx, events) = StateMachineBuilder::new()
.with_seed(seed.clone())
.with_phase(sum2)
.with_sum_ratio(sum_ratio)
.with_update_ratio(update_ratio)
.with_min_sum(n_summers)
.with_min_update(n_updaters)
.with_expected_participants(n_updaters + n_summers)
.with_mask_config(mask_settings().into())
.build();
assert!(state_machine.is_sum2());
// Let the summer compose its sum2 message from the updater's local seed
// dictionary.
let msg = summer
.compose_sum2_message(coord_keys.public, &local_seed_dict, masked_model.data.len())
.unwrap();
// Send the request and drive the state machine concurrently; both futures
// are awaited together through `tokio::join!`.
let req = async { request_tx.clone().sum2(&msg).await.unwrap() };
let transition = async { state_machine.next().await.unwrap() };
let ((), state_machine) = tokio::join!(req, transition);
assert!(state_machine.is_unmask());
let PhaseState {
inner: unmask_state,
..
} = state_machine.into_unmask_phase_state();
// Exactly one distinct mask must have been recorded, with a count of one.
assert_eq!(unmask_state.mask_dict().len(), 1);
let (mask, count) = unmask_state.mask_dict().iter().next().unwrap().clone();
assert_eq!(*count, 1);
// Unmasking the aggregation with the recorded mask must give back the
// model the updater started from.
let unmasked_model = unmask_state
.aggregation()
.unwrap()
.clone()
.unmask(mask.clone());
assert_eq!(unmasked_model, model);
// The latest phase event is still the sum2 one.
// NOTE(review): presumably because the unmask phase has not been run yet
// at this point — confirm against the state machine's `next()`.
assert_eq!(
events.phase_listener().get_latest(),
Event {
round_id: seed.clone(),
event: PhaseName::Sum2,
}
);
}
}