1use hyperium::{Uri, uri::Parts};
2use spin_executor::CancelToken;
3use std::future::Future;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use tonic::body::Body;
7use tower_service::Service;
8use wasi::io::poll::Pollable;
9use wasi_hyperium::{IncomingHttpBody, hyperium1::send_outbound_request, poll::PollableRegistry};
10
11pub struct WasiGrpcEndpoint {
12 endpoint: Uri,
13}
14
15impl WasiGrpcEndpoint {
16 pub fn new(endpoint: Uri) -> Self {
17 WasiGrpcEndpoint { endpoint }
18 }
19}
20
21impl Service<hyperium::Request<Body>> for WasiGrpcEndpoint {
22 type Response = hyperium::Response<IncomingHttpBody<SpinExecutorPoller>>;
23 type Error = wasi_hyperium::Error;
24 #[allow(clippy::type_complexity)]
25 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
26
27 fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
28 Poll::Ready(Ok(()))
29 }
30
31 fn call(&mut self, mut req: hyperium::Request<Body>) -> Self::Future {
32 let Parts {
33 scheme, authority, ..
34 } = self.endpoint.clone().into_parts();
35
36 let mut parts = std::mem::take(req.uri_mut()).into_parts();
37 parts.authority = authority;
38 parts.scheme = scheme;
39
40 *req.uri_mut() = parts.try_into().unwrap();
41
42 Box::pin(send_outbound_request(req, SpinExecutorPoller))
43 }
44}
45
/// [`PollableRegistry`] that registers WASI pollables with the Spin executor.
///
/// Stateless unit type, so it is freely constructible and cloneable.
#[derive(Clone, Debug, Default)]
pub struct SpinExecutorPoller;
48
49impl PollableRegistry for SpinExecutorPoller {
50 type RegisteredPollable = CancelToken;
51
52 fn register_pollable(&self, cx: &mut Context, pollable: Pollable) -> Self::RegisteredPollable {
53 spin_executor::push_waker_and_get_token(pollable, cx.waker().clone())
54 }
55
56 fn poll(&self) -> bool {
58 panic!("not supported for wasi-grpc")
59 }
60
61 fn block_on<T>(
63 &self,
64 _fut: impl std::future::Future<Output = T>,
65 ) -> Result<T, wasi_hyperium::poll::Stalled> {
66 panic!("not supported for wasi-grpc")
67 }
68}