webtonic_server/
lib.rs

//! Server crate of the [`WebTonic`](https://github.com/Sawchord/webtonic) project.
//!
//! This crate only contains the [`Server`](Server), which unpacks the requests that the client
//! has sent over the websocket connection and hands them to the registered services.
//! It is designed to mimic the
//! [`Tonic`](https://docs.rs/tonic/0.3.1/tonic/transport/struct.Server.html) implementation.
7
8use bytes::{Bytes, BytesMut};
9use core::{
10    marker::{Send, Sync},
11    task::Context,
12    task::Poll,
13};
14use futures::{future, StreamExt};
15use http::{request::Request, response::Response};
16use prost::Message as ProstMessage;
17use std::net::SocketAddr;
18use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
19use tokio_stream::wrappers::UnboundedReceiverStream;
20use tonic::{body::BoxBody, codegen::Never, transport::NamedService, Status};
21use tower_service::Service;
22use warp::{
23    ws::{Message, WebSocket},
24    Filter,
25};
26use webtonic_proto::Call;
27
/// The server endpoint of the `WebTonic` websocket bridge.
///
/// This is designed to be used similarly to the
/// [`Tonic`](https://github.com/hyperium/tonic/tree/master/tonic/src/transport) implementation.
///
/// # Example
/// Assuming we have the
/// [greeter example](https://github.com/hyperium/tonic/blob/master/examples/proto/helloworld/helloworld.proto)
/// in scope, we can serve an endpoint like so. The example relies on the generated greeter
/// types and on running inside an async context, so it is not compiled as a doctest.
/// ```ignore
/// let greeter = MyGreeter::default();
///
/// webtonic_server::Server::builder()
///     .add_service(GreeterServer::new(greeter))
///     .serve(([127, 0, 0, 1], 8080))
///     .await;
/// ```
#[derive(Debug, Clone)]
pub struct Server {}
47
48impl Server {
49    /// Create a new [`Server`](Server) builder.
50    ///
51    /// # Returns
52    /// A [`Server`](Server) in default configuration.
53    pub fn builder() -> Self {
54        Self {}
55    }
56
57    /// [service]: https://docs.rs/tower-service/0.3.0/tower_service/trait.Service.html
58    /// Add a [`Service`][service] to the route (see [example](Server)).
59    ///
60    /// # Arguments
61    /// - `service`: the [`Service`][service] to add
62    ///
63    /// # Returns
64    /// - A [`Router`](Router), which included the old routes and the new service.
65    /// This also means you need to finish server configuration before calling this function.
66    pub fn add_service<A>(self, service: A) -> Router<A, Unimplemented>
67    where
68        A: Service<Request<BoxBody>, Response = Response<BoxBody>> + Sync + Send + 'static,
69    {
70        Router {
71            server: self,
72            root: Route(service, Unimplemented),
73        }
74    }
75}
76
77/// A [`Router`](Router) is used to compile [`Routes`](Route), by [adding services](Router::add_service).
78#[derive(Debug, Clone)]
79pub struct Router<A, B> {
80    server: Server,
81    root: Route<A, B>,
82}
83
84impl<A, B> Router<A, B> {
85    /// [service]: https://docs.rs/tower-service/0.3.0/tower_service/trait.Service.html
86    /// Add a [`Service`][service] to the route (see [example](Server)).
87    ///
88    /// # Arguments
89    /// - `service`: the [`Service`][service] to add
90    ///
91    /// # Returns
92    /// - A new [`Router`](Router), which included the old routes and the new service.
93    pub fn add_service<C>(self, service: C) -> Router<C, Route<A, B>>
94    where
95        C: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never>,
96    {
97        Router {
98            server: self.server,
99            root: Route(service, self.root),
100        }
101    }
102
103    /// Start serving the endpoint on the provided addres (see [example](Server)).
104    ///
105    /// # Arguments
106    /// - `addr`: The address on which to serve the endpoint.
107    ///
108    /// # Returns
109    /// - It doens't.
110    pub async fn serve<U>(self, addr: U)
111    where
112        U: Into<SocketAddr>,
113        A: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never>
114            + NamedService
115            + Clone
116            + Send
117            + Sync
118            + 'static,
119        A::Future: Send + 'static,
120        B: Service<(String, Request<BoxBody>), Response = Response<BoxBody>, Error = Never>
121            + Clone
122            + Send
123            + Sync
124            + 'static,
125        B::Future: Send + 'static,
126    {
127        let server_clone = warp::any().map(move || self.clone());
128
129        warp::serve(warp::path::end().and(warp::ws()).and(server_clone).map(
130            |ws: warp::ws::Ws, server_clone| {
131                ws.on_upgrade(|socket| handle_connection2(socket, server_clone))
132            },
133        ))
134        .run(addr)
135        .await;
136    }
137}
138
139/// Representation of a gRPC route.
140///
141/// You will likely not interact with this directly, but rather through the [`Server`](Server)
142/// and [`Router`](Router) structs.
143#[derive(Debug, Clone)]
144pub struct Route<A, B>(A, B);
145
146impl<A, B> Service<(String, Request<BoxBody>)> for Route<A, B>
147where
148    A: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never> + NamedService,
149    A::Future: Send + 'static,
150    B: Service<(String, Request<BoxBody>), Response = Response<BoxBody>, Error = Never>,
151    B::Future: Send + 'static,
152{
153    type Response = Response<BoxBody>;
154    type Error = Never;
155    type Future = future::Either<A::Future, B::Future>;
156
157    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158        Ok(()).into()
159    }
160
161    fn call(&mut self, req: (String, Request<BoxBody>)) -> Self::Future {
162        if req.0.eq(<A as NamedService>::NAME) {
163            future::Either::Left(self.0.call(req.1))
164        } else {
165            future::Either::Right(self.1.call((req.0, req.1)))
166        }
167    }
168}
169
170/// The unimplemented service sends `unimplemented` errors on any request.
171///
172/// This is used as the fallthrough route in gRPC.
173#[derive(Default, Clone, Debug)]
174pub struct Unimplemented;
175
176impl Service<(String, Request<BoxBody>)> for Unimplemented {
177    type Response = Response<BoxBody>;
178    type Error = Never;
179    type Future = future::Ready<Result<Self::Response, Self::Error>>;
180
181    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182        Ok(()).into()
183    }
184
185    fn call(&mut self, _req: (String, Request<BoxBody>)) -> Self::Future {
186        future::ok(
187            http::Response::builder()
188                .status(200)
189                .header("grpc-status", "12")
190                .header("content-type", "application/grpc")
191                .body(BoxBody::empty())
192                .unwrap(),
193        )
194    }
195}
196
197async fn handle_connection2<A, B>(ws: WebSocket, routes: Router<A, B>)
198where
199    A: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never>
200        + NamedService
201        + Clone,
202    A::Future: Send + 'static,
203    B: Service<(String, Request<BoxBody>), Response = Response<BoxBody>, Error = Never> + Clone,
204    B::Future: Send + 'static,
205{
206    log::debug!("opening a new connection");
207
208    let (ws_tx, mut ws_rx) = ws.split();
209    let (tx, rx) = unbounded_channel();
210    // Create outbound task
211    tokio::task::spawn(UnboundedReceiverStream::new(rx).forward(ws_tx));
212
213    while let Some(msg) = ws_rx.next().await {
214        log::debug!("received message {:?}", msg);
215
216        // Try to send status error
217        // If even that fails, end task
218        macro_rules! status_err {
219            ($status: expr) => {
220                match return_status(&tx, $status).await {
221                    true => continue,
222                    false => break,
223                }
224            };
225        }
226
227        // Check that we got a message and it is binary
228        let msg = match msg {
229            Ok(msg) => {
230                if msg.is_binary() {
231                    Bytes::from(msg.into_bytes())
232                } else if msg.is_close() {
233                    log::debug!("channel was closed");
234                    break;
235                } else {
236                    status_err!(Status::invalid_argument(
237                        "websocket messages must be sent in binary"
238                    ))
239                }
240            }
241            Err(e) => status_err!(Status::internal(&format!(
242                "error on the websocket channel {:?}",
243                e
244            ))),
245        };
246
247        // Parse message first into protobuf then into http request
248        let call = match Call::decode(msg) {
249            Ok(call) => call,
250            Err(e) => status_err!(Status::internal(&format!("failed to decode call {:?}", e))),
251        };
252        let call = webtonic_proto::call_to_http_request(call).unwrap();
253
254        // Get the path to the requested service
255        let path: &str = call
256            .uri()
257            .path()
258            .split("/")
259            .collect::<Vec<&str>>()
260            .get(1)
261            .unwrap_or(&&"/");
262        log::debug!("request to path {:?}", path);
263
264        let mut response = match routes.root.clone().call((path.to_string(), call)).await {
265            Ok(response) => response,
266            Err(_e) => {
267                panic!("Tonic services never error");
268            }
269        };
270        log::debug!("got response {:?}", response);
271
272        // Turn reply first into protobuf, then into message
273        let reply = webtonic_proto::http_response_to_reply(&mut response).await;
274        let mut msg = BytesMut::new();
275        match reply.encode(&mut msg) {
276            Ok(()) => (),
277            Err(e) => status_err!(Status::internal(&format!("failed to decode reply {:?}", e))),
278        };
279        let msg = Message::binary(msg.as_ref());
280
281        // Return the message
282        log::debug!("sending response {:?}", msg);
283        match tx.send(Ok(msg)) {
284            Ok(()) => (),
285            Err(e) => {
286                log::warn!("stream no longer exists {:?}", e);
287                break;
288            }
289        }
290    }
291}
292
293async fn return_status(tx: &UnboundedSender<Result<Message, warp::Error>>, status: Status) -> bool {
294    log::warn!("error while processing msg, returning status {:?}", status);
295    let mut response = status.to_http();
296
297    let reply = webtonic_proto::http_response_to_reply(&mut response).await;
298    let mut msg = BytesMut::new();
299    reply.encode(&mut msg).unwrap();
300    let msg = Message::binary(msg.as_ref());
301
302    match tx.send(Ok(msg)) {
303        Ok(()) => true,
304        Err(_) => false,
305    }
306}