1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
//! A [`Transport`](crate::transport::Transport) that uses [Hyper](https://docs.rs/hyper) to serve stateful functions over HTTP.
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};

use crate::invocation_bridge::InvocationBridge;
use bytes::buf::BufExt;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use protobuf::Message;
use tokio::runtime;

use statefun_proto::http_function::ToFunction;

use crate::function_registry::FunctionRegistry;
use crate::transport::Transport;

/// A [Transport](crate::transport::Transport) that serves stateful functions on an HTTP
/// endpoint at the given `bind_address`.
#[derive(Debug, Clone)]
pub struct HyperHttpTransport {
    /// Socket address the HTTP server binds to when the transport is run.
    bind_address: SocketAddr,
}

impl HyperHttpTransport {
    /// Creates a new `HyperHttpTransport` that will serve stateful functions at the given
    /// `bind_address` once it is run.
    ///
    /// No socket is opened here; binding happens when the transport is run.
    pub fn new(bind_address: SocketAddr) -> Self {
        Self { bind_address }
    }
}

impl Transport for HyperHttpTransport {
    fn run(self, function_registry: FunctionRegistry) -> Result<(), failure::Error> {
        log::info!(
            "Hyper transport will start listening on {}",
            self.bind_address
        );

        let mut runtime = runtime::Builder::new()
            .threaded_scheduler()
            .enable_all()
            .build()?;

        let function_registry = Arc::new(Mutex::new(function_registry));

        runtime.block_on(async {
            let make_svc = make_service_fn(|_conn| {
                let function_registry = Arc::clone(&function_registry);
                async move {
                    Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
                        let function_registry = Arc::clone(&function_registry);
                        async move { handle_request(function_registry, req).await }
                    }))
                }
            });
            let server = Server::bind(&self.bind_address).serve(make_svc);
            let graceful = server.with_graceful_shutdown(shutdown_signal());

            if let Err(e) = graceful.await {
                eprintln!("server error: {}", e);
            }
        });

        Ok(())
    }
}

async fn handle_request(
    function_registry: Arc<Mutex<FunctionRegistry>>,
    req: Request<Body>,
) -> Result<Response<Body>, failure::Error> {
    let (_parts, body) = req.into_parts();
    log::debug!("Parts {:#?}", _parts);

    let full_body = hyper::body::to_bytes(body).await?;
    let to_function: ToFunction = protobuf::parse_from_reader(&mut full_body.reader())?;
    let from_function = {
        let function_registry = function_registry.lock().unwrap();
        function_registry.invoke_from_proto(to_function)?
    };

    log::debug!("Response: {:#?}", from_function);

    let encoded_result = from_function.write_to_bytes()?;

    let response = Response::builder()
        .header("content-type", "application/octet-stream")
        .body(encoded_result.into())?;

    log::debug!("Succesfully encoded response.");

    Ok(response)
}

/// Completes once the process receives CTRL+C. It is passed to hyper's
/// `with_graceful_shutdown` to stop the server.
///
/// # Panics
///
/// Panics if the CTRL+C signal handler cannot be installed.
async fn shutdown_signal() {
    let ctrl_c = tokio::signal::ctrl_c();
    ctrl_c.await.expect("failed to install CTRL+C signal handler");
}