//! Shared building blocks of the participant state machine: the per-phase [`State`] and [`Phase`] wrappers, the [`SharedState`] common to all phases, and the [`Step`] trait and [`Progress`] type used to drive a phase forward.
use async_trait::async_trait; use derive_more::From; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::{debug, error, info, warn}; use super::{Awaiting, NewRound, SendingSum, SendingSum2, SendingUpdate, Sum, Sum2, Update, IO}; use crate::{ settings::{MaxMessageSize, PetSettings}, state_machine::{StateMachine, TransitionOutcome}, MessageEncoder, }; use xaynet_core::{ common::{RoundParameters, RoundSeed}, crypto::{ByteObject, PublicEncryptKey, SigningKeyPair}, mask::{self, DataType, MaskConfig, Model}, message::Payload, }; /// State of the state machine #[derive(Debug, Serialize, Deserialize)] pub struct State<P> { /// data specific to the current phase pub private: Box<P>, /// data common to most of the phases pub shared: Box<SharedState>, } impl<P> State<P> { /// Create a new state pub fn new(shared: Box<SharedState>, private: Box<P>) -> Self { Self { shared, private } } } /// A dynamically dispatched [`IO`] object. pub(crate) type PhaseIo = Box<dyn IO<Model = Box<dyn AsRef<Model> + Send>>>; /// Represent the state machine in a specific phase pub struct Phase<P> { /// State of the phase. pub(super) state: State<P>, /// Opaque client for performing IO tasks: talking with the /// coordinator API, loading models, etc. pub(super) io: PhaseIo, } impl<P> std::fmt::Debug for Phase<P> where P: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Phase") .field("state", &self.state) .field("io", &"PhaseIo") .finish() } } /// Store for all the data that are common to all the phases #[derive(Serialize, Deserialize, Debug)] pub struct SharedState { /// Keys that identify the participant. They are used to sign the /// PET message sent by the participant. pub keys: SigningKeyPair, /// Scalar used for masking pub scalar: f64, /// Maximum message size the participant can send. Messages larger /// than `message_size` are split in several parts. 
pub message_size: MaxMessageSize, /// Current round parameters pub round_params: RoundParameters, } /// Get arbitrary round parameters. These round parameters are never used, we just /// temporarily use them in the [`SharedState`] when creating a new state machine. The /// first thing the state machine does when it runs, is to fetch the real round /// parameters from the coordinator. fn dummy_round_parameters() -> RoundParameters { RoundParameters { pk: PublicEncryptKey::zeroed(), sum: 0.0, update: 0.0, seed: RoundSeed::zeroed(), mask_config: MaskConfig { group_type: mask::GroupType::Integer, data_type: mask::DataType::F32, bound_type: mask::BoundType::B0, model_type: mask::ModelType::M3, } .into(), model_length: 0, } } impl SharedState { pub fn new(settings: PetSettings) -> Self { Self { keys: settings.keys, scalar: settings.scalar, message_size: settings.max_message_size, round_params: dummy_round_parameters(), } } } /// A trait that each `Phase<P>` implements. When `Step::step` is called, the phase /// tries to do a small piece of work. #[async_trait] pub trait Step { /// Represent an attempt to make progress within a phase. If the step results in a /// change in the phase state, the updated state machine is returned as /// `TransitionOutcome::Complete`. If no progress can be made, the state machine is /// returned unchanged as `TransitionOutcome::Pending`. async fn step(mut self) -> TransitionOutcome; } #[macro_export] macro_rules! try_progress { ($progress:expr) => {{ use $crate::state_machine::{Progress, TransitionOutcome}; match $progress { // No progress can be made. 
Return the state machine as is Progress::Stuck(phase) => return TransitionOutcome::Pending(phase.into()), // Further progress can be made but require more work, so don't return Progress::Continue(phase) => phase, // Progress has been made, return the updated state machine Progress::Updated(state_machine) => return TransitionOutcome::Complete(state_machine), } }}; } /// Represent the presence or absence of progress being made during a phase. #[derive(Debug)] pub enum Progress<P> { /// No progress can be made currently. Stuck(Phase<P>), /// More work needs to be done for progress to be made. Continue(Phase<P>), /// Progress has been made and resulted in this new state machine. Updated(StateMachine), } impl<P> Phase<P> where Phase<P>: Step + Into<StateMachine>, { /// Try to make some progress in the execution of the PET protocol. There are three /// possible outcomes: /// /// 1. no progress can currently be made and the phase state is unchanged /// 2. progress is made but the state machine does not transition to a new /// phase. Internally, the phase state is changed though. /// 3. progress is made and the state machine transitions to a new phase. /// /// In case `1.`, the state machine is returned unchanged, wrapped in /// [`TransitionOutcome::Pending`] to indicate to the caller that the state machine /// wasn't updated. In case `2.` and `3.` the updated state machine is returned /// wrapped in [`TransitionOutcome::Complete`]. 
pub async fn step(mut self) -> TransitionOutcome { match self.check_round_freshness().await { RoundFreshness::Unknown => TransitionOutcome::Pending(self.into()), RoundFreshness::Outdated => { info!("a new round started: updating the round parameters and resetting the state machine"); self.io.notify_new_round(); TransitionOutcome::Complete( Phase::<NewRound>::new( State::new(self.state.shared, Box::new(NewRound)), self.io, ) .into(), ) } RoundFreshness::Fresh => { debug!("round is still fresh, continuing from where we left off"); <Self as Step>::step(self).await } } } /// Check whether the coordinator has published new round parameters. In other /// words, this checks whether a new round has started. async fn check_round_freshness(&mut self) -> RoundFreshness { match self.io.get_round_params().await { Err(e) => { warn!("failed to fetch round parameters {:?}", e); RoundFreshness::Unknown } Ok(params) => { if params == self.state.shared.round_params { debug!("round parameters didn't change"); RoundFreshness::Fresh } else { info!("fetched fresh round parameters"); self.state.shared.round_params = params; RoundFreshness::Outdated } } } } } /// Trait for building [`Phase<P>`] from a [`State<P>`]. /// /// Note that we could just use [`Phase::new`] for this. However we want to be able to /// customize the conversion for each phase. For instance, when building a /// `Phase<Update>` from an `Update`, we want to emit some events with the `io` /// object. It is cleaner to wrap this custom logic in a trait impl. pub(crate) trait IntoPhase<P> { /// Build the phase with the given `io` object fn into_phase(self, io: PhaseIo) -> Phase<P>; } impl<P> Phase<P> { /// Build a new phase with the given state and io object. This should not be called /// directly. Instead, use the [`IntoPhase`] trait to construct a phase. pub(crate) fn new(state: State<P>, io: PhaseIo) -> Self { Phase { state, io } } /// Instantiate a message encoder for the given payload. 
/// /// The encoder takes care of converting the given `payload` into one or several /// signed and encrypted PET messages. pub fn message_encoder(&self, payload: Payload) -> MessageEncoder { MessageEncoder::new( self.state.shared.keys.clone(), payload, self.state.shared.round_params.pk, self.state .shared .message_size .max_payload_size() .unwrap_or(0), ) // the encoder rejects Chunk payload, but in the state // machine, we never manually create such payloads so // unwrapping is fine .unwrap() } /// Return the local model configuration of the model that is expected in the update phase. pub fn local_model_config(&self) -> LocalModelConfig { LocalModelConfig { data_type: self.state.shared.round_params.mask_config.vect.data_type, len: self.state.shared.round_params.model_length, } } #[cfg(test)] pub(crate) fn with_io_mock<F>(&mut self, f: F) where F: FnOnce(&mut super::MockIO), { let mut mock = super::MockIO::new(); f(&mut mock); self.io = Box::new(mock); } #[cfg(test)] pub(crate) fn check_io_mock(&mut self) { // dropping the mock forces the checks to run. We replace it // by an empty one, so that we detect if a method is called // un-expectedly afterwards let _ = std::mem::replace(&mut self.io, Box::new(super::MockIO::new())); } } #[derive(Debug)] /// The local model configuration of the model that is expected in the update phase. pub struct LocalModelConfig { /// The expected data type of the local model. // In the current state it is not possible to configure a coordinator in which // the scalar data type and the model data type are different. Therefore, we assume here // that the scalar data type is the same as the model data type. pub data_type: DataType, /// The expected length of the local model. pub len: usize, } #[derive(Error, Debug)] #[error("failed to send a PET message")] pub struct SendMessageError; /// Round freshness indicator pub enum RoundFreshness { /// A new round started. 
The current round is outdated Outdated, /// We were not able to check whether a new round started Unknown, /// The current round is still going Fresh, } /// A serializable representation of a phase state. /// /// We cannot serialize the state directly, even though it implements `Serialize`, because deserializing it would require knowing its type in advance: /// /// ```ignore /// // `buf` is a Vec<u8> that contains a serialized state that we want to deserialize /// let state: State<???> = State::deserialize(&buf[..]).unwrap(); /// ``` #[derive(Serialize, Deserialize, From, Debug)] pub enum SerializableState { NewRound(State<NewRound>), Awaiting(State<Awaiting>), Sum(State<Sum>), Update(State<Update>), Sum2(State<Sum2>), SendingSum(State<SendingSum>), SendingUpdate(State<SendingUpdate>), SendingSum2(State<SendingSum2>), } impl<P> Into<SerializableState> for Phase<P> where State<P>: Into<SerializableState>, { fn into(self) -> SerializableState { self.state.into() } }