sidevm_host_runtime/
rocket_stream.rs1use std::pin::Pin;
2
3use anyhow::{anyhow, Result};
4use rocket::{
5 data::{ByteUnit, IoHandler, IoStream},
6 http::Status,
7 request::{FromRequest, Outcome},
8 response::Responder,
9 Data, Request,
10};
11use sidevm_env::messages::{HttpHead, HttpResponseHead};
12use tokio::{
13 io::{split, AsyncWriteExt, DuplexStream},
14 sync::mpsc::Sender as ChannelSender,
15 sync::oneshot::channel as oneshot_channel,
16};
17use tracing::error;
18
19use crate::{service::Command, IncomingHttpRequest};
20
/// The parts of an incoming Rocket request that are forwarded onward.
///
/// Populated by this type's `FromRequest` implementation and converted into an
/// [`HttpHead`] by `into_head`.
pub struct RequestInfo {
    /// HTTP method rendered as a string (e.g. `GET`).
    method: String,
    /// Host of the request, or an empty string when the request has none.
    host: String,
    /// Query string without the leading `?`; empty when absent.
    query: String,
    /// Request headers as `(name, value)` pairs. When built by the
    /// `FromRequest` implementation, the `Cookie` header is excluded.
    headers: Vec<(String, String)>,
}
27
28pub struct StreamResponse {
29 head: HttpResponseHead,
30 io_stream: DuplexStream,
31}
32
33impl StreamResponse {
34 pub fn new(head: HttpResponseHead, io_stream: DuplexStream) -> Self {
35 Self { head, io_stream }
36 }
37}
38
39#[rocket::async_trait]
40impl IoHandler for StreamResponse {
41 async fn io(self: Pin<Box<Self>>, io: IoStream) -> std::io::Result<()> {
42 let Self { io_stream, .. } = *Pin::into_inner(self);
43 let (mut server_reader, mut server_writer) = split(io_stream);
44 let (mut client_reader, mut client_writer) = split(io);
45 let (res_c2s, res_s2c) = tokio::join! {
46 tokio::io::copy(&mut client_reader, &mut server_writer),
47 tokio::io::copy(&mut server_reader, &mut client_writer),
48 };
49 if let Err(err) = res_c2s {
50 error!(target: "sidevm", "Failed to copy from client to server: {err}");
51 return Err(err);
52 }
53 if let Err(err) = res_s2c {
54 error!(target: "sidevm", "Failed to copy from server to client: {err}");
55 return Err(err);
56 }
57 Ok(())
58 }
59}
60
61impl<'r> Responder<'r, 'r> for StreamResponse {
62 fn respond_to(mut self, _req: &'r Request<'_>) -> rocket::response::Result<'r> {
63 let mut builder = rocket::response::Response::build();
64 self.head
65 .headers
66 .retain(|(name, _)| name.to_lowercase() != "set-cookie");
67 if Status::new(self.head.status) == Status::SwitchingProtocols {
68 builder.status(Status::ServiceUnavailable);
71 let mut protocol = String::new();
72 for (name, value) in self.head.headers.drain(..) {
73 let name = name.to_lowercase();
74 if name == "upgrade" {
75 protocol = value.to_string();
76 }
77 if name != "connection" && name != "upgrade" {
78 builder.raw_header_adjoin(name, value);
79 }
80 }
81 builder.upgrade(protocol, self);
82 builder.streamed_body(&[] as &[u8]);
83 } else {
84 builder.status(Status::new(self.head.status));
85 for (name, value) in self.head.headers.into_iter() {
86 builder.raw_header_adjoin(name, value);
87 }
88 builder.streamed_body(self.io_stream);
89 }
90 Ok(builder.finalize())
91 }
92}
93
94#[rocket::async_trait]
95impl<'r> FromRequest<'r> for RequestInfo {
96 type Error = &'static str;
97
98 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
99 let method = req.method().to_string();
100 let uri = req.uri();
101 let query = uri.query().map(|s| s.to_string()).unwrap_or_default();
102 let host = req.host().map(|s| s.to_string()).unwrap_or_default();
103 let headers = req
104 .headers()
105 .iter()
106 .filter_map(|header| {
107 if header.name.as_uncased_str() == "cookie" {
108 None
109 } else {
110 Some((header.name.to_string(), header.value.to_string()))
111 }
112 })
113 .collect();
114 Outcome::Success(Self {
115 method,
116 host,
117 query,
118 headers,
119 })
120 }
121}
122
123impl RequestInfo {
124 fn into_head(self, path: &str) -> HttpHead {
125 let Self {
126 method,
127 host,
128 query,
129 headers,
130 } = self;
131 let mut url = format!("http://{}/{}", host, path);
132 if !query.is_empty() {
133 url.push('?');
134 url.push_str(&query);
135 }
136 HttpHead {
137 method,
138 url,
139 headers,
140 }
141 }
142}
143
144fn is_upgrade_request(req: &RequestInfo) -> bool {
145 req.headers
146 .iter()
147 .find_map(|(name, value)| {
148 if name.to_lowercase() == "connection" {
149 Some(value.to_lowercase() == "upgrade")
150 } else {
151 None
152 }
153 })
154 .unwrap_or(false)
155}
156
157pub async fn connect(
158 head: RequestInfo,
159 path: &str,
160 body: Option<Data<'_>>,
161 command_tx: ChannelSender<Command>,
162) -> Result<StreamResponse> {
163 let is_upgrade = is_upgrade_request(&head);
164 let (response_tx, response_rx) = oneshot_channel();
165 let (mut stream0, stream1) = tokio::io::duplex(1024);
166 let command = Command::HttpRequest(IncomingHttpRequest {
167 head: head.into_head(path),
168 body_stream: stream1,
169 response_tx,
170 });
171 command_tx
172 .send(command)
173 .await
174 .or(Err(anyhow!("Command channel closed")))?;
175 if !is_upgrade {
176 if let Some(body) = body {
178 let data_stream = body.open(ByteUnit::max_value());
179 let stream0 = &mut stream0;
180 let result: Result<()> = async move {
181 data_stream.stream_to(&mut *stream0).await?;
182 stream0
183 .shutdown()
184 .await
185 .or(Err(anyhow!("Stream shutdown error")))?;
186 Ok(())
187 }
188 .await;
189 if let Err(err) = result {
190 error!(target: "sidevm", "Failed to pipe the body: {err:?}");
191 }
192 }
193 }
194 let resposne = response_rx
195 .await
196 .map_err(|_| anyhow!("Response channel closed"))??;
197 Ok(StreamResponse::new(resposne, stream0))
198}