mod mask_length;
mod model;
mod round_parameters;
mod seed_dict;
mod sum_dict;
pub use self::{
mask_length::{MaskLengthRequest, MaskLengthResponse, MaskLengthService},
model::{ModelRequest, ModelResponse, ModelService},
round_parameters::{RoundParamsRequest, RoundParamsResponse, RoundParamsService},
seed_dict::{SeedDictRequest, SeedDictResponse, SeedDictService},
sum_dict::{SumDictRequest, SumDictResponse, SumDictService},
};
use std::task::{Context, Poll};
use futures::future::poll_fn;
use tower::{layer::Layer, Service, ServiceBuilder};
use crate::state_machine::events::EventSubscriber;
/// Asynchronous, fallible accessors for the per-round data items.
///
/// Each method takes `&mut self` and resolves to the matching response type, or to a
/// [`FetchError`] if the value could not be obtained.
#[async_trait]
pub trait Fetcher {
    /// Fetches the round parameters as a [`RoundParamsResponse`].
    async fn round_params(&mut self) -> Result<RoundParamsResponse, FetchError>;

    /// Fetches the sum dictionary as a [`SumDictResponse`].
    async fn sum_dict(&mut self) -> Result<SumDictResponse, FetchError>;

    /// Fetches the seed dictionary as a [`SeedDictResponse`].
    async fn seed_dict(&mut self) -> Result<SeedDictResponse, FetchError>;

    /// Fetches the mask length as a [`MaskLengthResponse`].
    async fn mask_length(&mut self) -> Result<MaskLengthResponse, FetchError>;

    /// Fetches the model as a [`ModelResponse`].
    async fn model(&mut self) -> Result<ModelResponse, FetchError>;
}
/// Error type returned by every [`Fetcher`] method.
pub type FetchError = anyhow::Error;

/// Converts any boxable service error into a [`FetchError`].
///
/// The error is first converted into a boxed `std::error::Error`. Its `Debug` representation
/// is then embedded in the message `Fetcher failed: …`. The resulting `anyhow::Error` is a
/// plain message, so the original error is not retained as a `source()`.
fn into_fetch_error<E>(error: E) -> FetchError
where
    E: Into<Box<dyn ::std::error::Error + 'static + Sync + Send>>,
{
    let boxed: Box<dyn ::std::error::Error + 'static + Sync + Send> = error.into();
    anyhow::anyhow!("Fetcher failed: {:?}", boxed)
}
#[async_trait]
impl<RoundParams, SumDict, SeedDict, MaskLength, Model> Fetcher
for Fetchers<RoundParams, SumDict, SeedDict, MaskLength, Model>
where
Self: Send + Sync + 'static,
RoundParams: Service<RoundParamsRequest, Response = RoundParamsResponse> + Send + 'static,
<RoundParams as Service<RoundParamsRequest>>::Future: Send + Sync + 'static,
<RoundParams as Service<RoundParamsRequest>>::Error:
Into<Box<dyn ::std::error::Error + 'static + Sync + Send>>,
MaskLength: Service<MaskLengthRequest, Response = MaskLengthResponse> + Send + 'static,
<MaskLength as Service<MaskLengthRequest>>::Future: Send + Sync + 'static,
<MaskLength as Service<MaskLengthRequest>>::Error:
Into<Box<dyn ::std::error::Error + 'static + Sync + Send>>,
Model: Service<ModelRequest, Response = ModelResponse> + Send + 'static,
<Model as Service<ModelRequest>>::Future: Send + Sync + 'static,
<Model as Service<ModelRequest>>::Error:
Into<Box<dyn ::std::error::Error + 'static + Sync + Send>>,
SeedDict: Service<SeedDictRequest, Response = SeedDictResponse> + Send + 'static,
<SeedDict as Service<SeedDictRequest>>::Future: Send + Sync + 'static,
<SeedDict as Service<SeedDictRequest>>::Error:
Into<Box<dyn ::std::error::Error + 'static + Sync + Send>>,
SumDict: Service<SumDictRequest, Response = SumDictResponse> + Send + 'static,
<SumDict as Service<SumDictRequest>>::Future: Send + Sync + 'static,
<SumDict as Service<SumDictRequest>>::Error:
Into<Box<dyn ::std::error::Error + 'static + Sync + Send>>,
{
async fn round_params(&mut self) -> Result<RoundParamsResponse, FetchError> {
poll_fn(|cx| {
<RoundParams as Service<RoundParamsRequest>>::poll_ready(&mut self.round_params, cx)
})
.await
.map_err(into_fetch_error)?;
Ok(<RoundParams as Service<RoundParamsRequest>>::call(
&mut self.round_params,
RoundParamsRequest,
)
.await
.map_err(into_fetch_error)?)
}
async fn mask_length(&mut self) -> Result<MaskLengthResponse, FetchError> {
poll_fn(|cx| {
<MaskLength as Service<MaskLengthRequest>>::poll_ready(&mut self.mask_length, cx)
})
.await
.map_err(into_fetch_error)?;
Ok(<MaskLength as Service<MaskLengthRequest>>::call(
&mut self.mask_length,
MaskLengthRequest,
)
.await
.map_err(into_fetch_error)?)
}
async fn model(&mut self) -> Result<ModelResponse, FetchError> {
poll_fn(|cx| <Model as Service<ModelRequest>>::poll_ready(&mut self.model, cx))
.await
.map_err(into_fetch_error)?;
Ok(
<Model as Service<ModelRequest>>::call(&mut self.model, ModelRequest)
.await
.map_err(into_fetch_error)?,
)
}
async fn seed_dict(&mut self) -> Result<SeedDictResponse, FetchError> {
poll_fn(|cx| <SeedDict as Service<SeedDictRequest>>::poll_ready(&mut self.seed_dict, cx))
.await
.map_err(into_fetch_error)?;
Ok(
<SeedDict as Service<SeedDictRequest>>::call(&mut self.seed_dict, SeedDictRequest)
.await
.map_err(into_fetch_error)?,
)
}
async fn sum_dict(&mut self) -> Result<SumDictResponse, FetchError> {
poll_fn(|cx| <SumDict as Service<SumDictRequest>>::poll_ready(&mut self.sum_dict, cx))
.await
.map_err(into_fetch_error)?;
Ok(
<SumDict as Service<SumDictRequest>>::call(&mut self.sum_dict, SumDictRequest)
.await
.map_err(into_fetch_error)?,
)
}
}
pub(in crate::services) struct FetcherService<S>(S);
impl<S, R> Service<R> for FetcherService<S>
where
S: Service<R>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: R) -> Self::Future {
self.0.call(req)
}
}
/// Tower [`Layer`] that wraps any service in a `FetcherService`.
pub(in crate::services) struct FetcherLayer;

impl<S> Layer<S> for FetcherLayer {
    type Service = FetcherService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        FetcherService(inner)
    }
}
/// Bundle of the five services that back the `Fetcher` implementation.
///
/// Each type parameter is the service used for the correspondingly named data item.
#[derive(Debug, Clone)]
pub struct Fetchers<RoundParams, SumDict, SeedDict, MaskLength, Model> {
    round_params: RoundParams,
    sum_dict: SumDict,
    seed_dict: SeedDict,
    mask_length: MaskLength,
    model: Model,
}

impl<RoundParams, SumDict, SeedDict, MaskLength, Model>
    Fetchers<RoundParams, SumDict, SeedDict, MaskLength, Model>
{
    /// Creates a `Fetchers` from one service per data item.
    ///
    /// The arguments are taken in the same order as the struct's fields:
    /// `round_params`, `sum_dict`, `seed_dict`, `mask_length`, `model`.
    pub fn new(
        round_params: RoundParams,
        sum_dict: SumDict,
        seed_dict: SeedDict,
        mask_length: MaskLength,
        model: Model,
    ) -> Self {
        Fetchers {
            model,
            mask_length,
            seed_dict,
            sum_dict,
            round_params,
        }
    }
}
pub fn fetcher(event_subscriber: &EventSubscriber) -> impl Fetcher + Sync + Send + Clone + 'static {
let round_params = ServiceBuilder::new()
.buffer(100)
.concurrency_limit(100)
.layer(FetcherLayer)
.service(RoundParamsService::new(event_subscriber));
let mask_length = ServiceBuilder::new()
.buffer(100)
.concurrency_limit(100)
.layer(FetcherLayer)
.service(MaskLengthService::new(event_subscriber));
let model = ServiceBuilder::new()
.buffer(100)
.concurrency_limit(100)
.layer(FetcherLayer)
.service(ModelService::new(event_subscriber));
let sum_dict = ServiceBuilder::new()
.buffer(100)
.concurrency_limit(100)
.layer(FetcherLayer)
.service(SumDictService::new(event_subscriber));
let seed_dict = ServiceBuilder::new()
.buffer(100)
.concurrency_limit(100)
.layer(FetcherLayer)
.service(SeedDictService::new(event_subscriber));
Fetchers::new(round_params, sum_dict, seed_dict, mask_length, model)
}