use std::sync::Arc;
use async_trait::async_trait;
use thiserror::Error;
use tokio::time::{timeout, Duration};
use tracing::{debug, info, warn};
use crate::{
state_machine::{
events::DictionaryUpdate,
phases::{Handler, Phase, PhaseName, PhaseState, PhaseStateError, Shared, Sum2},
requests::{StateMachineRequest, UpdateRequest},
RequestError,
StateMachine,
},
storage::{CoordinatorStorage, ModelStorage, StorageError},
};
use xaynet_core::{
mask::{Aggregation, MaskObject},
LocalSeedDict,
UpdateParticipantPublicKey,
};
#[derive(Error, Debug)]
pub enum UpdateStateError {
#[error("seed dictionary does not exists")]
NoSeedDict,
#[error("fetching seed dictionary failed: {0}")]
FetchSeedDict(StorageError),
}
#[derive(Debug)]
pub struct Update {
model_agg: Aggregation,
accepted: u64,
rejected: u64,
discarded: u64,
}
#[async_trait]
impl<C, M> Phase<C, M> for PhaseState<Update, C, M>
where
Self: Handler,
C: CoordinatorStorage,
M: ModelStorage,
{
const NAME: PhaseName = PhaseName::Update;
async fn run(&mut self) -> Result<(), PhaseStateError> {
let min_time = self.shared.state.min_update_time;
let max_time = self.shared.state.max_update_time;
debug!(
"in update phase for min {} and max {} seconds",
min_time, max_time,
);
self.process_during(Duration::from_secs(min_time)).await?;
let time_left = max_time - min_time;
timeout(Duration::from_secs(time_left), self.process_until_enough()).await??;
info!(
"in total {} update messages accepted (min {} and max {} required)",
self.private.accepted,
self.shared.state.min_update_count,
self.shared.state.max_update_count,
);
info!(
"in total {} update messages rejected",
self.private.rejected,
);
info!(
"in total {} update messages discarded",
self.private.discarded,
);
let seed_dict = self
.shared
.store
.seed_dict()
.await
.map_err(UpdateStateError::FetchSeedDict)?
.ok_or(UpdateStateError::NoSeedDict)?;
info!("broadcasting the global seed dictionary");
self.shared
.events
.broadcast_seed_dict(DictionaryUpdate::New(Arc::new(seed_dict)));
Ok(())
}
fn next(self) -> Option<StateMachine<C, M>> {
Some(PhaseState::<Sum2, _, _>::new(self.shared, self.private.model_agg).into())
}
}
#[async_trait]
impl<C, M> Handler for PhaseState<Update, C, M>
where
C: CoordinatorStorage,
M: ModelStorage,
{
async fn handle_request(&mut self, req: StateMachineRequest) -> Result<(), RequestError> {
if let StateMachineRequest::Update(UpdateRequest {
participant_pk,
local_seed_dict,
masked_model,
}) = req
{
self.update_seed_dict_and_aggregate_mask(
&participant_pk,
&local_seed_dict,
masked_model,
)
.await
} else {
Err(RequestError::MessageRejected)
}
}
fn has_enough_messages(&self) -> bool {
self.private.accepted >= self.shared.state.min_update_count
}
fn has_overmuch_messages(&self) -> bool {
self.private.accepted >= self.shared.state.max_update_count
}
fn increment_accepted(&mut self) {
self.private.accepted += 1;
debug!(
"{} update messages accepted (min {} and max {} required)",
self.private.accepted,
self.shared.state.min_update_count,
self.shared.state.max_update_count,
);
}
fn increment_rejected(&mut self) {
self.private.rejected += 1;
debug!("{} update messages rejected", self.private.rejected);
}
fn increment_discarded(&mut self) {
self.private.discarded += 1;
debug!("{} update messages discarded", self.private.discarded);
}
}
impl<C, M> PhaseState<Update, C, M>
where
C: CoordinatorStorage,
M: ModelStorage,
{
pub fn new(shared: Shared<C, M>) -> Self {
Self {
private: Update {
model_agg: Aggregation::new(
shared.state.round_params.mask_config,
shared.state.round_params.model_length,
),
accepted: 0,
rejected: 0,
discarded: 0,
},
shared,
}
}
async fn update_seed_dict_and_aggregate_mask(
&mut self,
pk: &UpdateParticipantPublicKey,
local_seed_dict: &LocalSeedDict,
mask_object: MaskObject,
) -> Result<(), RequestError> {
debug!("checking whether the masked model can be aggregated");
self.private
.model_agg
.validate_aggregation(&mask_object)
.map_err(|e| {
warn!("model aggregation error: {}", e);
RequestError::AggregationFailed
})?;
info!("updating the global seed dictionary");
self.add_local_seed_dict(pk, local_seed_dict)
.await
.map_err(|err| {
warn!("invalid local seed dictionary, ignoring update message");
err
})?;
info!("aggregating the masked model and scalar");
self.private.model_agg.aggregate(mask_object);
Ok(())
}
async fn add_local_seed_dict(
&mut self,
pk: &UpdateParticipantPublicKey,
local_seed_dict: &LocalSeedDict,
) -> Result<(), RequestError> {
self.shared
.store
.add_local_seed_dict(pk, local_seed_dict)
.await?
.into_inner()
.map_err(RequestError::from)
}
}
#[cfg(test)]
mod tests {
use serial_test::serial;
use super::*;
use crate::{
state_machine::{
events::Event,
tests::{
builder::StateMachineBuilder,
utils::{self, Participant},
},
},
storage::tests::init_store,
};
use xaynet_core::{
common::{RoundParameters, RoundSeed},
crypto::{ByteObject, EncryptKeyPair},
mask::{FromPrimitives, Model},
SeedDict,
SumDict,
UpdateSeedDict,
};
impl Update {
pub fn aggregation(&self) -> &Aggregation {
&self.model_agg
}
}
#[tokio::test]
#[serial]
pub async fn integration_update_to_sum2() {
utils::enable_logging();
let model_length = 4;
let round_params = RoundParameters {
pk: EncryptKeyPair::generate().public,
sum: 0.5,
update: 1.0,
seed: RoundSeed::generate(),
mask_config: utils::mask_config(),
model_length,
};
let n_updaters = 1;
let n_summers = 1;
let summer = utils::generate_summer(round_params.clone());
let updater = utils::generate_updater(round_params.clone());
let mut frozen_sum_dict = SumDict::new();
frozen_sum_dict.insert(summer.keys.public, summer.ephm_keys.public);
let aggregation = Aggregation::new(utils::mask_config(), model_length);
let mut store = init_store().await;
let (state_machine, request_tx, events) = StateMachineBuilder::new(store.clone())
.with_seed(round_params.seed.clone())
.with_phase(Update {
model_agg: aggregation.clone(),
accepted: 0,
rejected: 0,
discarded: 0,
})
.with_sum_ratio(round_params.sum)
.with_update_ratio(round_params.update)
.with_min_sum_count(n_summers)
.with_max_sum_count(n_summers + 10)
.with_min_update_count(n_updaters)
.with_max_update_count(n_updaters + 10)
.with_min_update_time(1)
.with_max_update_time(2)
.with_mask_config(utils::mask_settings().into())
.build();
store
.add_sum_participant(&summer.keys.public, &summer.ephm_keys.public)
.await
.unwrap();
assert!(state_machine.is_update());
let scalar = 1.0 / (n_updaters as f64 * round_params.update);
let model = Model::from_primitives(vec![0; model_length].into_iter()).unwrap();
let (mask_seed, masked_model) = updater.compute_masked_model(&model, scalar);
let local_seed_dict = Participant::build_seed_dict(&frozen_sum_dict, &mask_seed);
let update_msg =
updater.compose_update_message(masked_model.clone(), local_seed_dict.clone());
let request_fut = async { request_tx.msg(&update_msg).await.unwrap() };
let transition_fut = async { state_machine.next().await.unwrap() };
let (_response, state_machine) = tokio::join!(request_fut, transition_fut);
let PhaseState {
private: sum2_state,
..
} = state_machine.into_sum2_phase_state();
let sum_dict = store.sum_dict().await.unwrap().unwrap();
assert_eq!(sum_dict, frozen_sum_dict);
assert_eq!(
<Aggregation as Into<MaskObject>>::into(sum2_state.aggregation().clone()),
masked_model
);
let best_masks = store.best_masks().await.unwrap();
assert!(best_masks.is_none());
assert_eq!(
events.phase_listener().get_latest(),
Event {
round_id: 0,
event: PhaseName::Update,
}
);
let mut global_seed_dict = SeedDict::new();
let mut entry = UpdateSeedDict::new();
let encrypted_mask_seed = local_seed_dict.values().next().unwrap().clone();
entry.insert(updater.keys.public, encrypted_mask_seed);
global_seed_dict.insert(summer.keys.public, entry);
assert_eq!(
events.seed_dict_listener().get_latest(),
Event {
round_id: 0,
event: DictionaryUpdate::New(Arc::new(global_seed_dict)),
}
);
}
}