1use notify::Watcher;
2use std::{convert::Infallible, path::PathBuf, sync::Arc};
3use tokio::sync::{Mutex, Notify};
4use tokio_stream::StreamExt;
5use which::which;
6
/// Settings for [`run`]: where the proxy listens, where the child process listens, which
/// paths to watch for changes, and the shell command that (re)starts the child.
#[derive(Debug, Clone)]
pub struct Config {
    /// Address the proxy server binds to.
    pub host: std::net::IpAddr,
    /// Port the proxy server binds to.
    pub port: u16,
    /// Address requests are forwarded to; also exported to the child as `HOST`.
    pub child_host: std::net::IpAddr,
    /// Port requests are forwarded to; also exported to the child as `PORT`.
    pub child_port: u16,
    /// Paths watched for changes, resolved against the current directory.
    pub watch_paths: Vec<PathBuf>,
    /// Paths excluded from watching, resolved against the current directory.
    pub ignore_paths: Vec<PathBuf>,
    /// Command run via `sh -c` on startup and after every batch of changes.
    pub command: String,
}
16
17pub async fn run(config: Config) {
18 let Config {
19 host,
20 port,
21 child_host,
22 child_port,
23 watch_paths,
24 ignore_paths,
25 command,
26 } = config;
27 let addr = std::net::SocketAddr::new(host, port);
28 let child_addr = std::net::SocketAddr::new(child_host, child_port);
29 let cwd = std::env::current_dir().unwrap();
30 let watch_paths: Vec<PathBuf> = watch_paths.into_iter().map(|path| cwd.join(path)).collect();
31 let ignore_paths: Vec<PathBuf> = ignore_paths
32 .into_iter()
33 .map(|path| cwd.join(path))
34 .collect();
35
36 enum State {
37 Ground,
38 Building {
39 notify: Arc<Notify>,
40 child: Option<std::process::Child>,
41 },
42 Running {
43 child: Option<std::process::Child>,
44 },
45 }
46 let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::Ground));
47 let (watch_events_tx, watch_events_rx) = tokio::sync::mpsc::unbounded_channel();
48 watch_events_tx.send(()).unwrap();
49
50 let mut watcher = notify::recommended_watcher(move |_: notify::Result<notify::Event>| {
52 watch_events_tx.send(()).unwrap();
53 })
54 .unwrap();
55 let mut walk_builder = ignore::WalkBuilder::new(watch_paths.first().unwrap());
56 for watch_path in watch_paths.iter().skip(1) {
57 walk_builder.add(watch_path);
58 }
59 walk_builder.filter_entry(move |entry| {
60 let path = entry.path();
61 let ignored = ignore_paths
62 .iter()
63 .any(|ignore_path| path.starts_with(ignore_path));
64 !ignored
65 });
66 let walk = walk_builder.build();
67 for entry in walk {
68 let entry = entry.unwrap();
69 let path = entry.path();
70 watcher
71 .watch(path, notify::RecursiveMode::NonRecursive)
72 .unwrap();
73 }
74
75 tokio::spawn({
76 let state = state.clone();
77 async move {
78 let watch_events =
79 tokio_stream::wrappers::UnboundedReceiverStream::new(watch_events_rx)
80 .chunks_timeout(1_000_000, std::time::Duration::from_millis(10));
81 tokio::pin!(watch_events);
82 while watch_events.next().await.is_some() {
83 if let State::Running { child } = &mut *state.lock().await {
85 let mut child = child.take().unwrap();
86 child.kill().ok();
87 child.wait().unwrap();
88 }
89 let notify = Arc::new(Notify::new());
91 let sh = which("sh").unwrap();
92 let child = std::process::Command::new(sh)
93 .args(vec!["-c", &command])
94 .env("HOST", &child_host.to_string())
95 .env("PORT", &child_port.to_string())
96 .spawn()
97 .unwrap();
98 *state.lock().await = State::Building {
99 notify: notify.clone(),
100 child: Some(child),
101 };
102 loop {
103 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
104 if let State::Building { child, .. } = &mut *state.lock().await {
105 if let Ok(Some(_)) | Err(_) = child.as_mut().unwrap().try_wait() {
106 break;
107 }
108 }
109 if tokio::net::TcpStream::connect(&child_addr).await.is_ok() {
110 break;
111 }
112 }
113 let child = if let State::Building { child, .. } = &mut *state.lock().await {
114 child.take().unwrap()
115 } else {
116 panic!()
117 };
118 *state.lock().await = State::Running { child: Some(child) };
119 notify.notify_waiters();
120 }
121 }
122 });
123
124 let handler = move |state: Arc<Mutex<State>>, mut request: http::Request<hyper::Body>| async move {
126 let notify = if let State::Building { notify, .. } = &mut *state.lock().await {
127 Some(notify.clone())
128 } else {
129 None
130 };
131 if let Some(notify) = notify {
132 notify.notified().await;
133 }
134 let child_authority = format!("{}:{}", child_host, child_port);
135 let child_authority = http::uri::Authority::from_maybe_shared(child_authority).unwrap();
136 *request.uri_mut() = http::Uri::builder()
137 .scheme("http")
138 .authority(child_authority)
139 .path_and_query(request.uri().path_and_query().unwrap().clone())
140 .build()
141 .unwrap();
142 hyper::Client::new()
143 .request(request)
144 .await
145 .unwrap_or_else(|_| {
146 http::Response::builder()
147 .status(http::StatusCode::SERVICE_UNAVAILABLE)
148 .body(hyper::Body::from("service unavailable"))
149 .unwrap()
150 })
151 };
152
153 let service = hyper::service::make_service_fn(|_| {
155 let state = state.clone();
156 async move {
157 Ok::<_, Infallible>(hyper::service::service_fn(
158 move |request: http::Request<hyper::Body>| {
159 let state = state.clone();
160 async move { Ok::<_, Infallible>(handler(state, request).await) }
161 },
162 ))
163 }
164 });
165 hyper::Server::bind(&addr).serve(service).await.unwrap();
166}