1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
use async_trait::async_trait;
use derive_more::From;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::{debug, error, info, warn};

use super::{Awaiting, NewRound, SendingSum, SendingSum2, SendingUpdate, Sum, Sum2, Update, IO};
use crate::{
    settings::{MaxMessageSize, PetSettings},
    state_machine::{StateMachine, TransitionOutcome},
    MessageEncoder,
};
use xaynet_core::{
    common::{RoundParameters, RoundSeed},
    crypto::{ByteObject, PublicEncryptKey, SigningKeyPair},
    mask::{self, DataType, MaskConfig, Model},
    message::Payload,
};

/// The complete state of the state machine while it is in phase `P`.
///
/// It is split in two halves: the data that only makes sense for the
/// current phase, and the data shared by most phases.
#[derive(Debug, Serialize, Deserialize)]
pub struct State<P> {
    /// Phase-specific data.
    pub private: Box<P>,
    /// Data common to most of the phases.
    pub shared: Box<SharedState>,
}

impl<P> State<P> {
    /// Assemble a state from its shared part and its phase-specific part.
    pub fn new(shared: Box<SharedState>, private: Box<P>) -> Self {
        State { private, shared }
    }
}

/// A dynamically dispatched [`IO`] object.
///
/// Its associated `Model` type is itself a boxed trait object
/// (`Box<dyn AsRef<Model> + Send>`), so the phases are not tied to one concrete
/// model type: anything that can be viewed as a [`Model`] and is `Send` fits.
pub(crate) type PhaseIo = Box<dyn IO<Model = Box<dyn AsRef<Model> + Send>>>;

/// Represent the state machine in a specific phase
///
/// `P` is the phase-specific data type (for instance `NewRound`, `Sum`,
/// `Update`, ...), stored in `state.private`.
pub struct Phase<P> {
    /// State of the phase.
    pub(super) state: State<P>,
    /// Opaque client for performing IO tasks: talking with the
    /// coordinator API, loading models, etc.
    pub(super) io: PhaseIo,
}

impl<P: std::fmt::Debug> std::fmt::Debug for Phase<P> {
    /// Format the phase state; the opaque IO object is shown as the
    /// placeholder string `"PhaseIo"` instead of its contents.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Phase");
        builder.field("state", &self.state);
        builder.field("io", &"PhaseIo");
        builder.finish()
    }
}

/// Store for all the data that are common to all the phases
///
/// It is carried over from one phase to the next, including when a new round
/// starts and the state machine is reset to the `NewRound` phase.
#[derive(Serialize, Deserialize, Debug)]
pub struct SharedState {
    /// Keys that identify the participant. They are used to sign the
    /// PET message sent by the participant.
    pub keys: SigningKeyPair,
    /// Scalar used for masking
    pub scalar: f64,
    /// Maximum message size the participant can send. Messages larger
    /// than `message_size` are split in several parts.
    pub message_size: MaxMessageSize,
    /// Current round parameters. They hold placeholder values until the
    /// first successful fetch from the coordinator, and are overwritten
    /// whenever the coordinator publishes different parameters.
    pub round_params: RoundParameters,
}

/// Get arbitrary round parameters. These round parameters are never used, we just
/// temporarily use them in the [`SharedState`] when creating a new state machine. The
/// first thing the state machine does when it runs, is to fetch the real round
/// parameters from the coordinator.
fn dummy_round_parameters() -> RoundParameters {
    RoundParameters {
        pk: PublicEncryptKey::zeroed(),
        sum: 0.0,
        update: 0.0,
        seed: RoundSeed::zeroed(),
        mask_config: MaskConfig {
            group_type: mask::GroupType::Integer,
            data_type: mask::DataType::F32,
            bound_type: mask::BoundType::B0,
            model_type: mask::ModelType::M3,
        }
        .into(),
        model_length: 0,
    }
}

impl SharedState {
    pub fn new(settings: PetSettings) -> Self {
        Self {
            keys: settings.keys,
            scalar: settings.scalar,
            message_size: settings.max_message_size,
            round_params: dummy_round_parameters(),
        }
    }
}

/// A trait that each `Phase<P>` implements. When `Step::step` is called, the phase
/// tries to do a small piece of work.
#[async_trait]
pub trait Step {
    /// Represent an attempt to make progress within a phase.
    ///
    /// If the step results in a change in the phase state, the updated state machine
    /// is returned as `TransitionOutcome::Complete`. If no progress can be made, the
    /// state machine is returned unchanged as `TransitionOutcome::Pending`.
    ///
    /// The declaration takes `self` by value without a `mut` binding: bindings are
    /// meaningless in a declaration without a body. Implementors remain free to
    /// write `async fn step(mut self)`.
    async fn step(self) -> TransitionOutcome;
}

/// Unwrap a [`Progress`] value inside a `Step::step` implementation.
///
/// - `Progress::Stuck(phase)`: returns early from the enclosing function with
///   `TransitionOutcome::Pending(phase.into())`.
/// - `Progress::Continue(phase)`: evaluates to `phase`, so the caller keeps working.
/// - `Progress::Updated(state_machine)`: returns early from the enclosing function
///   with `TransitionOutcome::Complete(state_machine)`.
///
/// Because of the early `return`s, the macro can only be used inside a function
/// whose return type is `TransitionOutcome`.
#[macro_export]
macro_rules! try_progress {
    ($progress:expr) => {{
        use $crate::state_machine::{Progress, TransitionOutcome};
        match $progress {
            // No progress can be made. Return the state machine as is
            Progress::Stuck(phase) => return TransitionOutcome::Pending(phase.into()),
            // Further progress can be made but requires more work, so don't return
            Progress::Continue(phase) => phase,
            // Progress has been made, return the updated state machine
            Progress::Updated(state_machine) => return TransitionOutcome::Complete(state_machine),
        }
    }};
}

/// Represent the presence or absence of progress being made during a phase.
///
/// Values of this type are typically consumed with the `try_progress!` macro,
/// which maps each variant to the appropriate control flow.
#[derive(Debug)]
pub enum Progress<P> {
    /// No progress can be made currently.
    Stuck(Phase<P>),
    /// More work needs to be done for progress to be made.
    Continue(Phase<P>),
    /// Progress has been made and resulted in this new state machine.
    Updated(StateMachine),
}

impl<P> Phase<P>
where
    Phase<P>: Step + Into<StateMachine>,
{
    /// Try to make some progress in the execution of the PET protocol. There are three
    /// possible outcomes:
    ///
    /// 1. no progress can currently be made and the phase state is unchanged
    /// 2. progress is made but the state machine does not transition to a new
    ///    phase. Internally, the phase state is changed though.
    /// 3. progress is made and the state machine transitions to a new phase.
    ///
    /// In case `1.`, the state machine is returned unchanged, wrapped in
    /// [`TransitionOutcome::Pending`] to indicate to the caller that the state machine
    /// wasn't updated. In case `2.` and `3.` the updated state machine is returned
    /// wrapped in [`TransitionOutcome::Complete`].
    pub async fn step(mut self) -> TransitionOutcome {
        match self.check_round_freshness().await {
            RoundFreshness::Unknown => TransitionOutcome::Pending(self.into()),
            RoundFreshness::Outdated => {
                info!("a new round started: updating the round parameters and resetting the state machine");
                self.io.notify_new_round();
                TransitionOutcome::Complete(
                    Phase::<NewRound>::new(
                        State::new(self.state.shared, Box::new(NewRound)),
                        self.io,
                    )
                    .into(),
                )
            }
            RoundFreshness::Fresh => {
                debug!("round is still fresh, continuing from where we left off");
                <Self as Step>::step(self).await
            }
        }
    }

    /// Check whether the coordinator has published new round parameters. In other
    /// words, this checks whether a new round has started.
    async fn check_round_freshness(&mut self) -> RoundFreshness {
        match self.io.get_round_params().await {
            Err(e) => {
                warn!("failed to fetch round parameters {:?}", e);
                RoundFreshness::Unknown
            }
            Ok(params) => {
                if params == self.state.shared.round_params {
                    debug!("round parameters didn't change");
                    RoundFreshness::Fresh
                } else {
                    info!("fetched fresh round parameters");
                    self.state.shared.round_params = params;
                    RoundFreshness::Outdated
                }
            }
        }
    }
}

/// Trait for building [`Phase<P>`] from a [`State<P>`].
///
/// Note that we could just use [`Phase::new`] for this. However we want to be able to
/// customize the conversion for each phase. For instance, when building a
/// `Phase<Update>` from an `Update`, we want to emit some events with the `io`
/// object. It is cleaner to wrap this custom logic in a trait impl.
pub(crate) trait IntoPhase<P> {
    /// Build the phase with the given `io` object.
    ///
    /// The `io` object is moved into the resulting phase.
    fn into_phase(self, io: PhaseIo) -> Phase<P>;
}

impl<P> Phase<P> {
    /// Build a new phase with the given state and IO object.
    ///
    /// This should not be called directly. Instead, use the [`IntoPhase`] trait to
    /// construct a phase, so that any phase-specific construction logic runs.
    pub(crate) fn new(state: State<P>, io: PhaseIo) -> Self {
        Phase { state, io }
    }

    /// Instantiate a message encoder for the given payload.
    ///
    /// The encoder takes care of converting the given `payload` into one or several
    /// signed and encrypted PET messages. It is built from a clone of the participant's
    /// signing keys, the `pk` of the current round parameters, and the maximum payload
    /// size derived from the configured maximum message size. The size falls back to `0`
    /// when `max_payload_size()` yields no value.
    ///
    /// # Panics
    ///
    /// Panics if [`MessageEncoder::new`] fails. The encoder rejects `Chunk` payloads,
    /// but the state machine never creates such payloads manually, so a failure here
    /// is a bug.
    pub fn message_encoder(&self, payload: Payload) -> MessageEncoder {
        let shared = &self.state.shared;
        // NOTE(review): `0` presumably tells the encoder there is no size limit;
        // confirm against `MessageEncoder::new`.
        let max_payload_size = shared.message_size.max_payload_size().unwrap_or(0);
        MessageEncoder::new(shared.keys.clone(), payload, shared.round_params.pk, max_payload_size)
            .expect("failed to create message encoder (Chunk payloads are never built here)")
    }

    /// Return the local model configuration of the model that is expected in the update
    /// phase, taken from the current round parameters.
    pub fn local_model_config(&self) -> LocalModelConfig {
        LocalModelConfig {
            data_type: self.state.shared.round_params.mask_config.vect.data_type,
            len: self.state.shared.round_params.model_length,
        }
    }

    /// Replace the IO object with a fresh `MockIO` configured by `f`.
    #[cfg(test)]
    pub(crate) fn with_io_mock<F>(&mut self, f: F)
    where
        F: FnOnce(&mut super::MockIO),
    {
        let mut mock = super::MockIO::new();
        f(&mut mock);
        self.io = Box::new(mock);
    }

    /// Verify the expectations of the current IO mock.
    #[cfg(test)]
    pub(crate) fn check_io_mock(&mut self) {
        // dropping the mock forces the checks to run. We replace it
        // by an empty one, so that we detect if a method is called
        // un-expectedly afterwards
        drop(std::mem::replace(&mut self.io, Box::new(super::MockIO::new())));
    }
}

#[derive(Debug)]
/// The local model configuration of the model that is expected in the update phase.
pub struct LocalModelConfig {
    /// The expected data type of the local model.
    // In the current state it is not possible to configure a coordinator in which
    // the scalar data type and the model data type are different. Therefore, we assume here
    // that the scalar data type is the same as the model data type.
    pub data_type: DataType,
    /// The expected length of the local model.
    pub len: usize,
}

/// Error indicating that sending a PET message failed.
///
/// It carries no further detail about the cause of the failure.
#[derive(Error, Debug)]
#[error("failed to send a PET message")]
pub struct SendMessageError;

/// Round freshness indicator
///
/// Result of comparing the round parameters published by the coordinator with the
/// ones currently known by the state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoundFreshness {
    /// A new round started. The current round is outdated
    Outdated,
    /// We were not able to check whether a new round started
    Unknown,
    /// The current round is still going
    Fresh,
}

/// A serializable representation of a phase state.
///
/// Although [`State<P>`] implements `Serialize` and `Deserialize`, deserializing it
/// directly would require knowing the phase type `P` in advance:
///
/// ```ignore
/// // `buf` is a Vec<u8> that contains a serialized state that we want to deserialize
/// let state: State<???> = State::deserialize(&buf[..]).unwrap();
/// ```
///
/// This enum tags the state with its phase, so that it can be deserialized without
/// knowing the phase beforehand. Each `State<_>` converts into the matching variant
/// via `From`.
#[derive(Serialize, Deserialize, From, Debug)]
pub enum SerializableState {
    /// State of the `NewRound` phase.
    NewRound(State<NewRound>),
    /// State of the `Awaiting` phase.
    Awaiting(State<Awaiting>),
    /// State of the `Sum` phase.
    Sum(State<Sum>),
    /// State of the `Update` phase.
    Update(State<Update>),
    /// State of the `Sum2` phase.
    Sum2(State<Sum2>),
    /// State of the `SendingSum` phase.
    SendingSum(State<SendingSum>),
    /// State of the `SendingUpdate` phase.
    SendingUpdate(State<SendingUpdate>),
    /// State of the `SendingSum2` phase.
    SendingSum2(State<SendingSum2>),
}

/// Convert a phase into its serializable representation.
///
/// Only the phase state is kept; the IO object is discarded. Implementing `From`
/// rather than `Into` also provides `Into<SerializableState>` through the standard
/// library's blanket impl, so existing `phase.into()` calls keep working.
impl<P> From<Phase<P>> for SerializableState
where
    State<P>: Into<SerializableState>,
{
    fn from(phase: Phase<P>) -> Self {
        phase.state.into()
    }
}