1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
use crate::{
    aggregator::service::{Aggregator, ServiceHandle},
    common::client::{ClientId, Credentials, Token},
};
use tokio::net::TcpListener;
use tracing_futures::Instrument;
use warp::{
    http::{header::CONTENT_TYPE, method::Method, Response, StatusCode},
    Filter,
};

pub async fn serve<A: Aggregator + 'static>(bind_address: &str, handle: ServiceHandle<A>) {
    let handle = warp::any().map(move || handle.clone());
    let parent_span = tracing::Span::current();

    let download_global_weights = warp::get()
        .and(warp::path::param::<ClientId>())
        .and(warp::path::param::<Token>())
        .and(handle.clone())
        .and_then(move |id, token, handle: ServiceHandle<A>| {
            let span =
                trace_span!(parent: parent_span.clone(), "api_download_request", client_id = %id);
            async move {
                debug!("received download request");
                match handle.download(Credentials(id, token)).await {
                    Ok(weights) => Ok(Response::builder().body(weights)),
                    Err(_) => Err(warp::reject::not_found()),
                }
            }
            .instrument(span)
        })
        .with(warp::cors().allow_any_origin().allow_method(Method::GET))
        // We need to send the this content type back, otherwise the swagger ui does not understand
        // that the data is binary data.
        // Without the "content-type", swagger will show the data as text.
        .with(warp::reply::with::header(
            "Content-Type",
            "application/octet-stream",
        ));

    let parent_span = tracing::Span::current();
    let upload_local_weights = warp::post()
        .and(warp::path::param::<ClientId>())
        .and(warp::path::param::<Token>())
        .and(warp::body::bytes())
        .and(handle.clone())
        .and_then(move |id, token, weights, handle: ServiceHandle<A>| {
            let span =
                trace_span!(parent: parent_span.clone(), "api_upload_request", client_id = %id);

            async move {
                debug!("received upload request");
                match handle.upload(Credentials(id, token), weights).await {
                    Ok(()) => Ok(StatusCode::OK),
                    Err(_) => Err(warp::reject::not_found()),
                }
            }
            .instrument(span)
        })
        .with(
            warp::cors()
                .allow_any_origin()
                .allow_method(Method::POST)
                // Allow the "content-typ" header which is requested in the CORS preflight request.
                // Without this header, we will get an CORS error in the swagger ui.
                .allow_header(CONTENT_TYPE),
        );

    let mut listener = TcpListener::bind(bind_address).await.unwrap();

    info!("starting HTTP server on {}", bind_address);
    let log = warp::log("http");
    warp::serve(download_global_weights.or(upload_local_weights).with(log))
        .run_incoming(listener.incoming())
        .await
}