#[macro_use]
extern crate async_trait;
#[macro_use]
extern crate serde;
#[macro_use]
extern crate tracing;
use std::convert::TryFrom;
use std::time::Duration;
use thiserror::Error;
use tokio::time;
use xaynet_core::{crypto::ByteObject, mask::Model, CoordinatorPublicKey, InitError};
#[doc(hidden)]
pub mod mobile_client;
pub mod api;
mod participant;
pub use participant::{Participant, Task};
/// A model held as a plain vector of a single primitive element type.
///
/// NOTE(review): in this file the type only appears as the `Client::cached_model` field;
/// presumably it is a cached, type-specialised copy of the global model — confirm against
/// the code that fills it.
#[derive(Debug, Clone)]
pub enum CachedModel {
    /// Elements stored as 32-bit floats.
    F32(Vec<f32>),
    /// Elements stored as 64-bit floats.
    F64(Vec<f64>),
    /// Elements stored as 32-bit signed integers.
    I32(Vec<i32>),
    /// Elements stored as 64-bit signed integers.
    I64(Vec<i64>),
}
/// Errors reported while a participant handles PET round data.
///
/// In this file, `InvalidModel` is produced when the coordinator announces a model/mask
/// length that does not fit the platform's `usize`. Errors from composing the sum2 message
/// are also wrapped as this type (presumably either variant — confirm in `participant`).
#[derive(Debug, Error)]
pub enum PetError {
    /// A mask could not be used.
    #[error("Invalid mask")]
    InvalidMask,

    /// A model (or its announced length) could not be used.
    #[error("Invalid model")]
    InvalidModel,
}
#[derive(Debug, Error)]
pub enum ClientError<E: ::std::error::Error + ::std::fmt::Debug + 'static> {
#[error("failed to initialise participant: {0}")]
ParticipantInitErr(InitError),
#[error("an API request failed: {0}")]
Api(#[from] E),
#[error("error arising from participant")]
ParticipantErr(PetError),
#[error("{0} not ready yet")]
TooEarly(&'static str),
#[error("round outdated")]
RoundOutdated,
}
/// A client that polls the coordinator through an API client `C` and takes part in the
/// rounds it is selected for.
pub struct Client<C>
where
    C: api::ApiClient,
{
    /// The PET participant that computes round signatures and composes and seals this
    /// client's messages.
    pub participant: Participant,
    /// Timer ticked between consecutive polls while waiting for data.
    interval: time::Interval,
    /// Coordinator public key of the most recently seen round (zeroed before the first one).
    coordinator_pk: CoordinatorPublicKey,
    /// Set to `true` whenever a new coordinator public key (i.e. a new round) is observed.
    /// NOTE(review): never cleared here; presumably reset by whoever "checks" it.
    pub has_new_coord_pk_since_last_check: bool,
    /// Latest global model fetched from the coordinator, if any.
    pub global_model: Option<Model>,
    /// Type-specialised copy of a model.
    /// NOTE(review): not written in this file; presumably filled by the caller.
    pub cached_model: Option<CachedModel>,
    /// Set to `true` whenever a different global model is fetched.
    /// NOTE(review): never cleared here; presumably reset by whoever "checks" it.
    pub has_new_global_model_since_last_check: bool,
    /// Set to `true` whenever a different global model is fetched.
    /// NOTE(review): never cleared here; presumably reset when `cached_model` is refreshed.
    pub has_new_global_model_since_last_cache: bool,
    /// Locally trained model to submit when selected as an update participant. It must be
    /// provided from outside; the update task waits until it is `Some`.
    pub local_model: Option<Model>,
    /// Scalar passed along with the local model when composing the update message.
    pub scalar: f64,
    /// Identifier attached to this client's log events.
    id: u32,
    /// API client used for every coordinator request.
    client: C,
}
impl<C> Client<C>
where
C: api::ApiClient,
{
pub fn new(period: u64, id: u32, api: C) -> Result<Self, ClientError<C::Error>> {
Ok(Self {
participant: Participant::new().map_err(ClientError::ParticipantInitErr)?,
interval: time::interval(Duration::from_secs(period)),
coordinator_pk: CoordinatorPublicKey::zeroed(),
has_new_coord_pk_since_last_check: false,
global_model: None,
cached_model: None,
has_new_global_model_since_last_check: false,
has_new_global_model_since_last_cache: false,
local_model: None,
scalar: 1.0,
id,
client: api,
})
}
pub async fn start(&mut self) -> Result<(), ClientError<C::Error>> {
loop {
self.during_round().await?;
}
}
pub async fn during_round(&mut self) -> Result<Task, ClientError<C::Error>> {
debug!(client_id = %self.id, "polling for new round parameters");
loop {
let model = self.client.get_model().await?;
match (model, &self.global_model) {
(Some(new_model), None) => self.set_global_model(new_model),
(Some(new_model), Some(old_model)) if &new_model != old_model => {
self.set_global_model(new_model)
}
(None, _) => trace!(client_id = %self.id, "global model not ready yet"),
_ => trace!(client_id = %self.id, "global model still fresh"),
}
let round_params = self.client.get_round_params().await?;
if round_params.pk != self.coordinator_pk {
debug!(client_id = %self.id, "new round parameters received, determining task.");
self.coordinator_pk = round_params.pk;
let round_seed = round_params.seed.as_slice();
self.participant.compute_signatures(round_seed);
let (sum_frac, upd_frac) = (round_params.sum, round_params.update);
let task = self.participant.check_task(sum_frac, upd_frac);
self.has_new_coord_pk_since_last_check = true;
return match task {
Task::Sum => self.summer().await,
Task::Update => self.updater().await,
Task::None => self.unselected().await,
};
} else {
trace!(client_id = %self.id, "still the same round");
}
trace!(client_id = %self.id, "new round parameters not ready, retrying.");
self.interval.tick().await;
}
}
async fn unselected(&mut self) -> Result<Task, ClientError<C::Error>> {
debug!(client_id = %self.id, "not selected");
Ok(Task::None)
}
async fn summer(&mut self) -> Result<Task, ClientError<C::Error>> {
info!(client_id = %self.id, "selected to sum");
let msg = self.participant.compose_sum_message(self.coordinator_pk);
let sealed_msg = self.participant.seal_message(&self.coordinator_pk, &msg);
self.client.send_message(sealed_msg).await?;
debug!(client_id = %self.id, "polling for model/mask length");
let length = loop {
if let Some(length) = self.client.get_mask_length().await? {
if length > usize::MAX as u64 {
return Err(ClientError::ParticipantErr(PetError::InvalidModel));
} else {
break length as usize;
}
}
trace!(client_id = %self.id, "model/mask length not ready, retrying.");
self.interval.tick().await;
};
debug!(client_id = %self.id, "sum message sent, polling for seed dict.");
loop {
if let Some(seeds) = self.client.get_seeds(self.participant.pk).await? {
debug!(client_id = %self.id, "seed dict received, sending sum2 message.");
let msg = self
.participant
.compose_sum2_message(self.coordinator_pk, &seeds, length)
.map_err(|e| {
error!("failed to compose sum2 message with seeds: {:?}", &seeds);
ClientError::ParticipantErr(e)
})?;
let sealed_msg = self.participant.seal_message(&self.coordinator_pk, &msg);
self.client.send_message(sealed_msg).await?;
info!(client_id = %self.id, "sum participant completed a round");
break Ok(Task::Sum);
}
trace!(client_id = %self.id, "seed dict not ready, retrying.");
self.interval.tick().await;
}
}
async fn updater(&mut self) -> Result<Task, ClientError<C::Error>> {
info!(client_id = %self.id, "selected to update");
debug!(client_id = %self.id, "polling for local model");
let model = loop {
if let Some(model) = self.local_model.take() {
self.local_model = Some(model.clone());
break model;
}
trace!(client_id = %self.id, "local model not ready, retrying.");
self.interval.tick().await;
};
debug!(client_id = %self.id, "polling for sum dict");
loop {
if let Some(sums) = self.client.get_sums().await? {
debug!(client_id = %self.id, "sum dict received, sending update message.");
let msg = self.participant.compose_update_message(
self.coordinator_pk,
&sums,
self.scalar,
model,
);
let sealed_msg = self.participant.seal_message(&self.coordinator_pk, &msg);
self.client.send_message(sealed_msg).await?;
info!(client_id = %self.id, "update participant completed a round");
break Ok(Task::Update);
}
trace!(client_id = %self.id, "sum dict not ready, retrying.");
self.interval.tick().await;
}
}
fn set_global_model(&mut self, model: Model) {
debug!(client_id = %self.id, "updating global model");
self.global_model = Some(model);
self.has_new_global_model_since_last_check = true;
self.has_new_global_model_since_last_cache = true;
}
}