Skip to main content

statefun_sdk/transport/
hyper.rs

1//! `Transport` that uses [Hyper](http://docs.rs/hyper) to serve stateful functions.
2use std::convert::Infallible;
3use std::net::SocketAddr;
4use std::sync::{Arc, Mutex};
5
6use crate::invocation_bridge::InvocationBridge;
7use bytes::buf::BufExt;
8use hyper::service::{make_service_fn, service_fn};
9use hyper::{Body, Request, Response, Server};
10use protobuf::Message;
11use tokio::runtime;
12
13use statefun_proto::http_function::ToFunction;
14
15use crate::function_registry::FunctionRegistry;
16use crate::transport::Transport;
17
/// A [Transport](crate::transport::Transport) that serves stateful functions on an HTTP endpoint
/// at the given `bind_address`.
#[derive(Debug, Clone)]
pub struct HyperHttpTransport {
    /// Socket address the HTTP server binds to when the transport is run.
    bind_address: SocketAddr,
}
23
24impl HyperHttpTransport {
25    /// Creates a new `HyperHttpTransport` that can serve stateful functions at the given
26    /// `bind_address`.
27    pub fn new(bind_address: SocketAddr) -> HyperHttpTransport {
28        HyperHttpTransport { bind_address }
29    }
30}
31
32impl Transport for HyperHttpTransport {
33    fn run(self, function_registry: FunctionRegistry) -> Result<(), failure::Error> {
34        log::info!(
35            "Hyper transport will start listening on {}",
36            self.bind_address
37        );
38
39        let mut runtime = runtime::Builder::new()
40            .threaded_scheduler()
41            .enable_all()
42            .build()?;
43
44        let function_registry = Arc::new(Mutex::new(function_registry));
45
46        runtime.block_on(async {
47            let make_svc = make_service_fn(|_conn| {
48                let function_registry = Arc::clone(&function_registry);
49                async move {
50                    Ok::<_, Infallible>(service_fn(move |req: Request<Body>| {
51                        let function_registry = Arc::clone(&function_registry);
52                        async move { handle_request(function_registry, req).await }
53                    }))
54                }
55            });
56            let server = Server::bind(&self.bind_address).serve(make_svc);
57            let graceful = server.with_graceful_shutdown(shutdown_signal());
58
59            if let Err(e) = graceful.await {
60                eprintln!("server error: {}", e);
61            }
62        });
63
64        Ok(())
65    }
66}
67
68async fn handle_request(
69    function_registry: Arc<Mutex<FunctionRegistry>>,
70    req: Request<Body>,
71) -> Result<Response<Body>, failure::Error> {
72    let (_parts, body) = req.into_parts();
73    log::debug!("Parts {:#?}", _parts);
74
75    let full_body = hyper::body::to_bytes(body).await?;
76    let to_function: ToFunction = protobuf::parse_from_reader(&mut full_body.reader())?;
77    let from_function = {
78        let function_registry = function_registry.lock().unwrap();
79        function_registry.invoke_old(to_function)?
80    };
81
82    log::debug!("Response: {:#?}", from_function);
83
84    let encoded_result = from_function.write_to_bytes()?;
85
86    let response = Response::builder()
87        .header("content-type", "application/octet-stream")
88        .body(encoded_result.into())?;
89
90    log::debug!("Succesfully encoded response.");
91
92    Ok(response)
93}
94
95async fn shutdown_signal() {
96    tokio::signal::ctrl_c()
97        .await
98        .expect("failed to install CTRL+C signal handler");
99}