// sidevm_host_runtime/rocket_stream.rs

1use std::pin::Pin;
2
3use anyhow::{anyhow, Result};
4use rocket::{
5    data::{ByteUnit, IoHandler, IoStream},
6    http::Status,
7    request::{FromRequest, Outcome},
8    response::Responder,
9    Data, Request,
10};
11use sidevm_env::messages::{HttpHead, HttpResponseHead};
12use tokio::{
13    io::{split, AsyncWriteExt, DuplexStream},
14    sync::mpsc::Sender as ChannelSender,
15    sync::oneshot::channel as oneshot_channel,
16};
17use tracing::error;
18
19use crate::{service::Command, IncomingHttpRequest};
20
/// The parts of an incoming Rocket request that are forwarded to the sidevm.
///
/// Populated by its `FromRequest` implementation, which leaves out `Cookie` headers.
pub struct RequestInfo {
    /// HTTP method as text, e.g. `GET`.
    method: String,
    /// Request host as text; empty when the request carried none.
    host: String,
    /// Query string without the leading `?`; empty when absent.
    query: String,
    /// Request headers as `(name, value)` pairs, in request order.
    headers: Vec<(String, String)>,
}
27
28pub struct StreamResponse {
29    head: HttpResponseHead,
30    io_stream: DuplexStream,
31}
32
33impl StreamResponse {
34    pub fn new(head: HttpResponseHead, io_stream: DuplexStream) -> Self {
35        Self { head, io_stream }
36    }
37}
38
39#[rocket::async_trait]
40impl IoHandler for StreamResponse {
41    async fn io(self: Pin<Box<Self>>, io: IoStream) -> std::io::Result<()> {
42        let Self { io_stream, .. } = *Pin::into_inner(self);
43        let (mut server_reader, mut server_writer) = split(io_stream);
44        let (mut client_reader, mut client_writer) = split(io);
45        let (res_c2s, res_s2c) = tokio::join! {
46            tokio::io::copy(&mut client_reader, &mut server_writer),
47            tokio::io::copy(&mut server_reader, &mut client_writer),
48        };
49        if let Err(err) = res_c2s {
50            error!(target: "sidevm", "Failed to copy from client to server: {err}");
51            return Err(err);
52        }
53        if let Err(err) = res_s2c {
54            error!(target: "sidevm", "Failed to copy from server to client: {err}");
55            return Err(err);
56        }
57        Ok(())
58    }
59}
60
61impl<'r> Responder<'r, 'r> for StreamResponse {
62    fn respond_to(mut self, _req: &'r Request<'_>) -> rocket::response::Result<'r> {
63        let mut builder = rocket::response::Response::build();
64        self.head
65            .headers
66            .retain(|(name, _)| name.to_lowercase() != "set-cookie");
67        if Status::new(self.head.status) == Status::SwitchingProtocols {
68            // As Rocket requires to not set status to 101 and do not set headers 'Connection', 'Upgrade',
69            // we need to remove them from the response header.
70            builder.status(Status::ServiceUnavailable);
71            let mut protocol = String::new();
72            for (name, value) in self.head.headers.drain(..) {
73                let name = name.to_lowercase();
74                if name == "upgrade" {
75                    protocol = value.to_string();
76                }
77                if name != "connection" && name != "upgrade" {
78                    builder.raw_header_adjoin(name, value);
79                }
80            }
81            builder.upgrade(protocol, self);
82            builder.streamed_body(&[] as &[u8]);
83        } else {
84            builder.status(Status::new(self.head.status));
85            for (name, value) in self.head.headers.into_iter() {
86                builder.raw_header_adjoin(name, value);
87            }
88            builder.streamed_body(self.io_stream);
89        }
90        Ok(builder.finalize())
91    }
92}
93
94#[rocket::async_trait]
95impl<'r> FromRequest<'r> for RequestInfo {
96    type Error = &'static str;
97
98    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
99        let method = req.method().to_string();
100        let uri = req.uri();
101        let query = uri.query().map(|s| s.to_string()).unwrap_or_default();
102        let host = req.host().map(|s| s.to_string()).unwrap_or_default();
103        let headers = req
104            .headers()
105            .iter()
106            .filter_map(|header| {
107                if header.name.as_uncased_str() == "cookie" {
108                    None
109                } else {
110                    Some((header.name.to_string(), header.value.to_string()))
111                }
112            })
113            .collect();
114        Outcome::Success(Self {
115            method,
116            host,
117            query,
118            headers,
119        })
120    }
121}
122
123impl RequestInfo {
124    fn into_head(self, path: &str) -> HttpHead {
125        let Self {
126            method,
127            host,
128            query,
129            headers,
130        } = self;
131        let mut url = format!("http://{}/{}", host, path);
132        if !query.is_empty() {
133            url.push('?');
134            url.push_str(&query);
135        }
136        HttpHead {
137            method,
138            url,
139            headers,
140        }
141    }
142}
143
144fn is_upgrade_request(req: &RequestInfo) -> bool {
145    req.headers
146        .iter()
147        .find_map(|(name, value)| {
148            if name.to_lowercase() == "connection" {
149                Some(value.to_lowercase() == "upgrade")
150            } else {
151                None
152            }
153        })
154        .unwrap_or(false)
155}
156
157pub async fn connect(
158    head: RequestInfo,
159    path: &str,
160    body: Option<Data<'_>>,
161    command_tx: ChannelSender<Command>,
162) -> Result<StreamResponse> {
163    let is_upgrade = is_upgrade_request(&head);
164    let (response_tx, response_rx) = oneshot_channel();
165    let (mut stream0, stream1) = tokio::io::duplex(1024);
166    let command = Command::HttpRequest(IncomingHttpRequest {
167        head: head.into_head(path),
168        body_stream: stream1,
169        response_tx,
170    });
171    command_tx
172        .send(command)
173        .await
174        .or(Err(anyhow!("Command channel closed")))?;
175    if !is_upgrade {
176        // If it is a vanilla HTTP request, we need to send the body.
177        if let Some(body) = body {
178            let data_stream = body.open(ByteUnit::max_value());
179            let stream0 = &mut stream0;
180            let result: Result<()> = async move {
181                data_stream.stream_to(&mut *stream0).await?;
182                stream0
183                    .shutdown()
184                    .await
185                    .or(Err(anyhow!("Stream shutdown error")))?;
186                Ok(())
187            }
188            .await;
189            if let Err(err) = result {
190                error!(target: "sidevm", "Failed to pipe the body: {err:?}");
191            }
192        }
193    }
194    let resposne = response_rx
195        .await
196        .map_err(|_| anyhow!("Response channel closed"))??;
197    Ok(StreamResponse::new(resposne, stream0))
198}