vertigo_cli/serve/
serve_run.rs1use axum::{
2 body::BoxBody,
3 extract::{Json, RawQuery, State},
4 http::{header::HeaderMap, HeaderValue, StatusCode, Uri},
5 response::Response,
6 routing::get,
7 Router,
8};
9use reqwest::header;
10use serde_json::Value;
11use std::{
12 sync::Arc,
13 time::{Duration, Instant},
14};
15use tokio::sync::{OnceCell, RwLock};
16use tower_http::services::ServeDir;
17
18use crate::serve::mount_path::MountPathConfig;
19use crate::{
20 commons::ErrorCode,
21 serve::{server_state::ServerState, ServeOptsInner},
22};
23
24use super::ServeOpts;
25
26static STATE: OnceCell<Arc<RwLock<Arc<ServerState>>>> = OnceCell::const_new();
27
28pub async fn run(opts: ServeOpts, port_watch: Option<u16>) -> Result<(), ErrorCode> {
29 log::info!("serve params => {opts:#?}");
30
31 let ServeOptsInner {
32 host,
33 port,
34 proxy,
35 env,
36 } = opts.inner;
37
38 let mount_path = MountPathConfig::new(opts.common.dest_dir)?;
39 let state = Arc::new(ServerState::new(mount_path, port_watch, env)?);
40
41 let ref_state = STATE
42 .get_or_init({
43 let state = state.clone();
44
45 move || Box::pin(async move { Arc::new(RwLock::new(state)) })
46 })
47 .await;
48
49 let serve_mount_path = state.mount_path.http_root();
50
51 let serve_dir = ServeDir::new(state.mount_path.fs_root());
52
53 *(ref_state.write().await) = state;
54
55 let mut app = Router::new()
56 .nest_service(&serve_mount_path, serve_dir)
57 .layer(axum::middleware::map_response(set_cache_header));
58
59 for (path, target) in proxy {
60 app = install_proxy(app, path, target, ref_state.clone());
61 }
62
63 let app = app.fallback(handler).with_state(ref_state.clone());
64
65 let Ok(addr) = format!("{host}:{port}").parse() else {
66 log::error!("Incorrect listening address");
67 return Err(ErrorCode::ServeCantOpenPort);
68 };
69
70 let ret = axum::Server::bind(&addr)
71 .serve(app.into_make_service())
72 .await;
73
74 if let Err(err) = ret {
75 log::error!("Can't bind/serve on {addr}: {err}");
76 Err(ErrorCode::ServeCantOpenPort)
77 } else {
78 log::info!("Listening on http://{}", addr);
79 Ok(())
80 }
81}
82
83async fn get_response(target_url: String) -> Response<BoxBody> {
84 let response = match reqwest::get(target_url.clone()).await {
85 Ok(response) => response,
86 Err(error) => {
87 let message = format!("Error fetching from url={target_url} error={error}");
88
89 let mut response = message.into_response();
90 *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
91
92 return response;
93 }
94 };
95
96 let headers = response.headers().clone();
97 let status = response.status();
98 let body = match response.bytes().await {
99 Ok(body) => body.to_vec(),
100 Err(error) => {
101 let message = format!("Error fetching body from url={target_url} error={error}");
102
103 let mut response = message.into_response();
104 *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
105
106 return response;
107 }
108 };
109
110 use axum::response::IntoResponse;
111 let mut response: Response<BoxBody> = body.into_response();
112
113 *response.headers_mut() = headers;
114 *response.status_mut() = status;
115
116 response
117}
118
119async fn post_response(target_url: String, headers: HeaderMap, body: Value) -> Response<BoxBody> {
120 let client = reqwest::Client::new();
121
122 let Ok(body) = serde_json::to_vec(&body)
123 .inspect_err(|err| log::error!("Error serializing request body: {err}"))
124 else {
125 let mut resp = Response::default();
126 *resp.status_mut() = StatusCode::from_u16(600).unwrap_or_default();
127 return resp;
128 };
129
130 let response = match client
131 .post(target_url.clone())
132 .headers(headers)
133 .body(body)
134 .send()
135 .await
136 {
137 Ok(response) => response,
138 Err(error) => {
139 let message = format!("Error fetching from url={target_url} error={error}");
140
141 let mut response = message.into_response();
142 *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
143
144 return response;
145 }
146 };
147
148 let headers = response.headers().clone();
149 let status = response.status();
150 let body = match response.bytes().await {
151 Ok(body) => body.to_vec(),
152 Err(error) => {
153 let message = format!("Error fetching body from url={target_url} error={error}");
154
155 let mut response = message.into_response();
156 *response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
157
158 return response;
159 }
160 };
161
162 use axum::response::IntoResponse;
163 let mut response: Response<BoxBody> = body.into_response();
164
165 *response.headers_mut() = headers;
166 *response.status_mut() = status;
167
168 response
169}
170
171fn install_proxy(
172 app: Router<Arc<RwLock<Arc<ServerState>>>>,
173 path: String,
174 target: String,
175 ref_state: Arc<RwLock<Arc<ServerState>>>,
176) -> Router<Arc<RwLock<Arc<ServerState>>>> {
177 let router = Router::new()
178 .fallback(
179 get({
180 let path = path.clone();
181 let target = target.clone();
182
183 move |url: Uri| async move {
184 let from_url = format!("{path}{url}");
185 let target_url = format!("{target}{url}");
186 log::info!("proxy get {from_url} -> {target_url}");
187
188 get_response(target_url).await
189 }
190 })
191 .post({
192 let path = path.clone();
193
194 move |url: Uri, headers: HeaderMap, body: Json<Value>| async move {
195 let from_url = format!("{path}{url}");
196 let target_url = format!("{target}{url}");
197 let Json(body) = body;
198 log::info!("proxy post {from_url} -> {target_url}");
199
200 post_response(target_url, headers, body).await
201 }
202 }),
203 )
204 .with_state(ref_state);
205
206 app.nest_service(path.as_str(), router)
207}
208
209#[axum::debug_handler]
210async fn handler(
211 url: Uri,
212 RawQuery(query): RawQuery,
213 State(state): State<Arc<RwLock<Arc<ServerState>>>>,
214) -> Response<String> {
215 let state = state.read().await.clone();
216
217 let now = Instant::now();
218 let uri = {
219 let url = url.path();
220
221 match query {
222 Some(query) => format!("{url}?{query}"),
223 None => url.to_string(),
224 }
225 };
226
227 log::debug!("Incoming request: {uri}");
228 let mut response_state = state.request(&uri).await;
229
230 let time = now.elapsed();
231 log::log!(
232 if time > Duration::from_secs(1) {
233 log::Level::Warn
234 } else {
235 log::Level::Info
236 },
237 "Response for request: {} {}ms {url}",
238 response_state.status,
239 time.as_millis()
240 );
241
242 if let Some(port_watch) = state.port_watch {
243 response_state.add_watch_script(port_watch);
244 }
245
246 if response_state.status.is_server_error() {
247 log::error!("WASM status: {}", response_state.status);
248 log::error!("WASM response: {}", response_state.body);
249 }
250
251 response_state.into()
252}
253
254async fn set_cache_header<B: Send>(mut response: Response<B>) -> Response<B> {
255 response.headers_mut().insert(
256 header::CACHE_CONTROL,
257 HeaderValue::from_static("public, max-age=31536000"),
258 );
259 response
260}