use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::Stream;
use tokio::sync::watch;
use xaynet_core::{
common::RoundParameters,
crypto::EncryptKeyPair,
mask::Model,
SeedDict,
SumDict,
};
use crate::state_machine::phases::PhaseName;
/// An event payload tagged with the id of the round in which it was published.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Event<E> {
    /// Id of the round this event belongs to.
    pub round_id: u64,
    /// The event payload itself.
    pub event: E,
}
/// Update notification for the model channel.
#[derive(Clone, Debug, PartialEq)]
pub enum ModelUpdate {
    /// Marks the model as invalid; this is also the channel's initial value.
    Invalidate,
    /// A new model, shared behind an `Arc` so listeners can clone it cheaply.
    New(Arc<Model>),
}
/// Update notification for the mask length channel.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MaskLengthUpdate {
    /// Marks the mask length as invalid; this is also the channel's initial value.
    Invalidate,
    /// A new mask length.
    New(usize),
}
/// Update notification for a dictionary channel (used for both the sum and
/// the seed dictionary).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DictionaryUpdate<D> {
    /// Marks the dictionary as invalid; this is also the channel's initial value.
    Invalidate,
    /// A new dictionary, shared behind an `Arc` so listeners can clone it cheaply.
    New(Arc<D>),
}
/// Publishing half of the event channels.
///
/// Owns one [`EventBroadcaster`] per kind of event, together with the round
/// id that gets stamped onto every event sent through them.
#[derive(Debug)]
pub struct EventPublisher {
    /// Round id attached to every outgoing event.
    round_id: u64,
    /// Sender for encryption key pair events.
    keys_tx: EventBroadcaster<EncryptKeyPair>,
    /// Sender for round parameter events.
    params_tx: EventBroadcaster<RoundParameters>,
    /// Sender for phase change events.
    phase_tx: EventBroadcaster<PhaseName>,
    /// Sender for model updates.
    model_tx: EventBroadcaster<ModelUpdate>,
    /// Sender for mask length updates.
    mask_length_tx: EventBroadcaster<MaskLengthUpdate>,
    /// Sender for sum dictionary updates.
    sum_dict_tx: EventBroadcaster<DictionaryUpdate<SumDict>>,
    /// Sender for seed dictionary updates.
    seed_dict_tx: EventBroadcaster<DictionaryUpdate<SeedDict>>,
}
/// Subscribing half of the event channels.
///
/// Keeps one [`EventListener`] per kind of event and hands out clones of
/// them on request.
#[derive(Debug)]
pub struct EventSubscriber {
    /// Listener for encryption key pair events.
    keys_rx: EventListener<EncryptKeyPair>,
    /// Listener for round parameter events.
    params_rx: EventListener<RoundParameters>,
    /// Listener for phase change events.
    phase_rx: EventListener<PhaseName>,
    /// Listener for model updates.
    model_rx: EventListener<ModelUpdate>,
    /// Listener for mask length updates.
    mask_length_rx: EventListener<MaskLengthUpdate>,
    /// Listener for sum dictionary updates.
    sum_dict_rx: EventListener<DictionaryUpdate<SumDict>>,
    /// Listener for seed dictionary updates.
    seed_dict_rx: EventListener<DictionaryUpdate<SeedDict>>,
}
impl EventPublisher {
pub fn init(
round_id: u64,
keys: EncryptKeyPair,
params: RoundParameters,
phase: PhaseName,
) -> (Self, EventSubscriber) {
let (keys_tx, keys_rx) = watch::channel::<Event<EncryptKeyPair>>(Event {
round_id,
event: keys,
});
let (phase_tx, phase_rx) = watch::channel::<Event<PhaseName>>(Event {
round_id,
event: phase,
});
let (model_tx, model_rx) = watch::channel::<Event<ModelUpdate>>(Event {
round_id,
event: ModelUpdate::Invalidate,
});
let (mask_length_tx, mask_length_rx) = watch::channel::<Event<MaskLengthUpdate>>(Event {
round_id,
event: MaskLengthUpdate::Invalidate,
});
let (sum_dict_tx, sum_dict_rx) =
watch::channel::<Event<DictionaryUpdate<SumDict>>>(Event {
round_id,
event: DictionaryUpdate::Invalidate,
});
let (seed_dict_tx, seed_dict_rx) =
watch::channel::<Event<DictionaryUpdate<SeedDict>>>(Event {
round_id,
event: DictionaryUpdate::Invalidate,
});
let (params_tx, params_rx) = watch::channel::<Event<RoundParameters>>(Event {
round_id,
event: params,
});
let publisher = EventPublisher {
round_id,
keys_tx: keys_tx.into(),
params_tx: params_tx.into(),
phase_tx: phase_tx.into(),
model_tx: model_tx.into(),
mask_length_tx: mask_length_tx.into(),
sum_dict_tx: sum_dict_tx.into(),
seed_dict_tx: seed_dict_tx.into(),
};
let subscriber = EventSubscriber {
keys_rx: keys_rx.into(),
params_rx: params_rx.into(),
phase_rx: phase_rx.into(),
model_rx: model_rx.into(),
mask_length_rx: mask_length_rx.into(),
sum_dict_rx: sum_dict_rx.into(),
seed_dict_rx: seed_dict_rx.into(),
};
(publisher, subscriber)
}
pub fn set_round_id(&mut self, id: u64) {
self.round_id = id;
}
fn event<T>(&self, event: T) -> Event<T> {
Event {
round_id: self.round_id,
event,
}
}
pub fn broadcast_keys(&mut self, keys: EncryptKeyPair) {
let _ = self.keys_tx.broadcast(self.event(keys));
}
pub fn broadcast_params(&mut self, params: RoundParameters) {
let _ = self.params_tx.broadcast(self.event(params));
}
pub fn broadcast_phase(&mut self, phase: PhaseName) {
let _ = self.phase_tx.broadcast(self.event(phase));
}
pub fn broadcast_model(&mut self, update: ModelUpdate) {
let _ = self.model_tx.broadcast(self.event(update));
}
pub fn broadcast_mask_length(&mut self, update: MaskLengthUpdate) {
let _ = self.mask_length_tx.broadcast(self.event(update));
}
pub fn broadcast_sum_dict(&mut self, update: DictionaryUpdate<SumDict>) {
let _ = self.sum_dict_tx.broadcast(self.event(update));
}
pub fn broadcast_seed_dict(&mut self, update: DictionaryUpdate<SeedDict>) {
let _ = self.seed_dict_tx.broadcast(self.event(update));
}
}
impl EventSubscriber {
pub fn keys_listener(&self) -> EventListener<EncryptKeyPair> {
self.keys_rx.clone()
}
pub fn params_listener(&self) -> EventListener<RoundParameters> {
self.params_rx.clone()
}
pub fn phase_listener(&self) -> EventListener<PhaseName> {
self.phase_rx.clone()
}
pub fn model_listener(&self) -> EventListener<ModelUpdate> {
self.model_rx.clone()
}
pub fn mask_length_listener(&self) -> EventListener<MaskLengthUpdate> {
self.mask_length_rx.clone()
}
pub fn sum_dict_listener(&self) -> EventListener<DictionaryUpdate<SumDict>> {
self.sum_dict_rx.clone()
}
pub fn seed_dict_listener(&self) -> EventListener<DictionaryUpdate<SeedDict>> {
self.seed_dict_rx.clone()
}
}
/// Receiving end of a single event channel, wrapping a `watch::Receiver`.
#[derive(Clone, Debug)]
pub struct EventListener<E>(watch::Receiver<Event<E>>);
impl<E> From<watch::Receiver<Event<E>>> for EventListener<E> {
fn from(receiver: watch::Receiver<Event<E>>) -> Self {
EventListener(receiver)
}
}
impl<E: Clone> EventListener<E> {
    /// Returns a clone of the value currently held by the channel.
    ///
    /// This reads the latest value right away; it does not wait for a new one.
    pub fn get_latest(&self) -> Event<E> {
        let current = self.0.borrow();
        (*current).clone()
    }
}
impl<E> Stream for EventListener<E>
where
    E: Clone,
{
    type Item = Event<E>;

    /// Delegates polling to the wrapped `watch::Receiver` stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Reaching the field mutably requires `Self: Unpin`, exactly as the
        // `Pin::new` call on the inner receiver does.
        let receiver = &mut self.get_mut().0;
        Pin::new(receiver).poll_next(cx)
    }
}
/// Sending end of a single event channel, wrapping a `watch::Sender`.
#[derive(Debug)]
pub struct EventBroadcaster<E>(watch::Sender<Event<E>>);
impl<E> EventBroadcaster<E> {
    /// Publishes `event` on the channel, best-effort.
    ///
    /// Any error returned by the underlying sender is discarded, so a failed
    /// send is silently ignored. (Tokio reports such an error when no
    /// receiver is left — NOTE(review): confirm against the tokio version in use.)
    fn broadcast(&self, event: Event<E>) {
        match self.0.broadcast(event) {
            Ok(()) => {}
            // Deliberately ignored: publishing must never fail for the caller.
            Err(_) => {}
        }
    }
}
impl<E> From<watch::Sender<Event<E>>> for EventBroadcaster<E> {
fn from(sender: watch::Sender<Event<E>>) -> Self {
Self(sender)
}
}