1use bytes::{Bytes, BytesMut};
9use core::{
10 marker::{Send, Sync},
11 task::Context,
12 task::Poll,
13};
14use futures::{future, StreamExt};
15use http::{request::Request, response::Response};
16use prost::Message as ProstMessage;
17use std::net::SocketAddr;
18use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
19use tokio_stream::wrappers::UnboundedReceiverStream;
20use tonic::{body::BoxBody, codegen::Never, transport::NamedService, Status};
21use tower_service::Service;
22use warp::{
23 ws::{Message, WebSocket},
24 Filter,
25};
26use webtonic_proto::Call;
27
28#[derive(Debug, Clone)]
46pub struct Server {}
47
48impl Server {
49 pub fn builder() -> Self {
54 Self {}
55 }
56
57 pub fn add_service<A>(self, service: A) -> Router<A, Unimplemented>
67 where
68 A: Service<Request<BoxBody>, Response = Response<BoxBody>> + Sync + Send + 'static,
69 {
70 Router {
71 server: self,
72 root: Route(service, Unimplemented),
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct Router<A, B> {
80 server: Server,
81 root: Route<A, B>,
82}
83
84impl<A, B> Router<A, B> {
85 pub fn add_service<C>(self, service: C) -> Router<C, Route<A, B>>
94 where
95 C: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never>,
96 {
97 Router {
98 server: self.server,
99 root: Route(service, self.root),
100 }
101 }
102
103 pub async fn serve<U>(self, addr: U)
111 where
112 U: Into<SocketAddr>,
113 A: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never>
114 + NamedService
115 + Clone
116 + Send
117 + Sync
118 + 'static,
119 A::Future: Send + 'static,
120 B: Service<(String, Request<BoxBody>), Response = Response<BoxBody>, Error = Never>
121 + Clone
122 + Send
123 + Sync
124 + 'static,
125 B::Future: Send + 'static,
126 {
127 let server_clone = warp::any().map(move || self.clone());
128
129 warp::serve(warp::path::end().and(warp::ws()).and(server_clone).map(
130 |ws: warp::ws::Ws, server_clone| {
131 ws.on_upgrade(|socket| handle_connection2(socket, server_clone))
132 },
133 ))
134 .run(addr)
135 .await;
136 }
137}
138
139#[derive(Debug, Clone)]
144pub struct Route<A, B>(A, B);
145
146impl<A, B> Service<(String, Request<BoxBody>)> for Route<A, B>
147where
148 A: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never> + NamedService,
149 A::Future: Send + 'static,
150 B: Service<(String, Request<BoxBody>), Response = Response<BoxBody>, Error = Never>,
151 B::Future: Send + 'static,
152{
153 type Response = Response<BoxBody>;
154 type Error = Never;
155 type Future = future::Either<A::Future, B::Future>;
156
157 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
158 Ok(()).into()
159 }
160
161 fn call(&mut self, req: (String, Request<BoxBody>)) -> Self::Future {
162 if req.0.eq(<A as NamedService>::NAME) {
163 future::Either::Left(self.0.call(req.1))
164 } else {
165 future::Either::Right(self.1.call((req.0, req.1)))
166 }
167 }
168}
169
170#[derive(Default, Clone, Debug)]
174pub struct Unimplemented;
175
176impl Service<(String, Request<BoxBody>)> for Unimplemented {
177 type Response = Response<BoxBody>;
178 type Error = Never;
179 type Future = future::Ready<Result<Self::Response, Self::Error>>;
180
181 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
182 Ok(()).into()
183 }
184
185 fn call(&mut self, _req: (String, Request<BoxBody>)) -> Self::Future {
186 future::ok(
187 http::Response::builder()
188 .status(200)
189 .header("grpc-status", "12")
190 .header("content-type", "application/grpc")
191 .body(BoxBody::empty())
192 .unwrap(),
193 )
194 }
195}
196
197async fn handle_connection2<A, B>(ws: WebSocket, routes: Router<A, B>)
198where
199 A: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Never>
200 + NamedService
201 + Clone,
202 A::Future: Send + 'static,
203 B: Service<(String, Request<BoxBody>), Response = Response<BoxBody>, Error = Never> + Clone,
204 B::Future: Send + 'static,
205{
206 log::debug!("opening a new connection");
207
208 let (ws_tx, mut ws_rx) = ws.split();
209 let (tx, rx) = unbounded_channel();
210 tokio::task::spawn(UnboundedReceiverStream::new(rx).forward(ws_tx));
212
213 while let Some(msg) = ws_rx.next().await {
214 log::debug!("received message {:?}", msg);
215
216 macro_rules! status_err {
219 ($status: expr) => {
220 match return_status(&tx, $status).await {
221 true => continue,
222 false => break,
223 }
224 };
225 }
226
227 let msg = match msg {
229 Ok(msg) => {
230 if msg.is_binary() {
231 Bytes::from(msg.into_bytes())
232 } else if msg.is_close() {
233 log::debug!("channel was closed");
234 break;
235 } else {
236 status_err!(Status::invalid_argument(
237 "websocket messages must be sent in binary"
238 ))
239 }
240 }
241 Err(e) => status_err!(Status::internal(&format!(
242 "error on the websocket channel {:?}",
243 e
244 ))),
245 };
246
247 let call = match Call::decode(msg) {
249 Ok(call) => call,
250 Err(e) => status_err!(Status::internal(&format!("failed to decode call {:?}", e))),
251 };
252 let call = webtonic_proto::call_to_http_request(call).unwrap();
253
254 let path: &str = call
256 .uri()
257 .path()
258 .split("/")
259 .collect::<Vec<&str>>()
260 .get(1)
261 .unwrap_or(&&"/");
262 log::debug!("request to path {:?}", path);
263
264 let mut response = match routes.root.clone().call((path.to_string(), call)).await {
265 Ok(response) => response,
266 Err(_e) => {
267 panic!("Tonic services never error");
268 }
269 };
270 log::debug!("got response {:?}", response);
271
272 let reply = webtonic_proto::http_response_to_reply(&mut response).await;
274 let mut msg = BytesMut::new();
275 match reply.encode(&mut msg) {
276 Ok(()) => (),
277 Err(e) => status_err!(Status::internal(&format!("failed to decode reply {:?}", e))),
278 };
279 let msg = Message::binary(msg.as_ref());
280
281 log::debug!("sending response {:?}", msg);
283 match tx.send(Ok(msg)) {
284 Ok(()) => (),
285 Err(e) => {
286 log::warn!("stream no longer exists {:?}", e);
287 break;
288 }
289 }
290 }
291}
292
293async fn return_status(tx: &UnboundedSender<Result<Message, warp::Error>>, status: Status) -> bool {
294 log::warn!("error while processing msg, returning status {:?}", status);
295 let mut response = status.to_http();
296
297 let reply = webtonic_proto::http_response_to_reply(&mut response).await;
298 let mut msg = BytesMut::new();
299 reply.encode(&mut msg).unwrap();
300 let msg = Message::binary(msg.as_ref());
301
302 match tx.send(Ok(msg)) {
303 Ok(()) => true,
304 Err(_) => false,
305 }
306}