// sunfish/watchserve.rs

use notify::Watcher;
use std::{convert::Infallible, path::{Path, PathBuf}, sync::Arc};
use tokio::sync::{Mutex, Notify};
use tokio_stream::StreamExt;
use which::which;

/// Settings for [`run`]: where the proxy listens, where the child process is
/// expected to listen, which paths to watch, and the command that (re)starts
/// the child.
#[derive(Clone, Debug)]
pub struct Config {
	/// Address the proxy server binds to.
	pub host: std::net::IpAddr,
	/// Port the proxy server binds to.
	pub port: u16,
	/// Address requests are forwarded to; also passed to the child as `HOST`.
	pub child_host: std::net::IpAddr,
	/// Port requests are forwarded to; also passed to the child as `PORT`.
	pub child_port: u16,
	/// Paths to watch for changes. Relative paths are joined onto the current
	/// working directory. Must not be empty.
	pub watch_paths: Vec<PathBuf>,
	/// Paths excluded from watching (anything under them is skipped). Relative
	/// paths are joined onto the current working directory.
	pub ignore_paths: Vec<PathBuf>,
	/// Shell command run with `sh -c` to build and start the child process.
	pub command: String,
}

17pub async fn run(config: Config) {
18	let Config {
19		host,
20		port,
21		child_host,
22		child_port,
23		watch_paths,
24		ignore_paths,
25		command,
26	} = config;
27	let addr = std::net::SocketAddr::new(host, port);
28	let child_addr = std::net::SocketAddr::new(child_host, child_port);
29	let cwd = std::env::current_dir().unwrap();
30	let watch_paths: Vec<PathBuf> = watch_paths.into_iter().map(|path| cwd.join(path)).collect();
31	let ignore_paths: Vec<PathBuf> = ignore_paths
32		.into_iter()
33		.map(|path| cwd.join(path))
34		.collect();
35
36	enum State {
37		Ground,
38		Building {
39			notify: Arc<Notify>,
40			child: Option<std::process::Child>,
41		},
42		Running {
43			child: Option<std::process::Child>,
44		},
45	}
46	let state: Arc<Mutex<State>> = Arc::new(Mutex::new(State::Ground));
47	let (watch_events_tx, watch_events_rx) = tokio::sync::mpsc::unbounded_channel();
48	watch_events_tx.send(()).unwrap();
49
50	// Run the file watcher.
51	let mut watcher = notify::recommended_watcher(move |_: notify::Result<notify::Event>| {
52		watch_events_tx.send(()).unwrap();
53	})
54	.unwrap();
55	let mut walk_builder = ignore::WalkBuilder::new(watch_paths.first().unwrap());
56	for watch_path in watch_paths.iter().skip(1) {
57		walk_builder.add(watch_path);
58	}
59	walk_builder.filter_entry(move |entry| {
60		let path = entry.path();
61		let ignored = ignore_paths
62			.iter()
63			.any(|ignore_path| path.starts_with(ignore_path));
64		!ignored
65	});
66	let walk = walk_builder.build();
67	for entry in walk {
68		let entry = entry.unwrap();
69		let path = entry.path();
70		watcher
71			.watch(path, notify::RecursiveMode::NonRecursive)
72			.unwrap();
73	}
74
75	tokio::spawn({
76		let state = state.clone();
77		async move {
78			let watch_events =
79				tokio_stream::wrappers::UnboundedReceiverStream::new(watch_events_rx)
80					.chunks_timeout(1_000_000, std::time::Duration::from_millis(10));
81			tokio::pin!(watch_events);
82			while watch_events.next().await.is_some() {
83				// Kill the previous child process if any.
84				if let State::Running { child } = &mut *state.lock().await {
85					let mut child = child.take().unwrap();
86					child.kill().ok();
87					child.wait().unwrap();
88				}
89				// Start the new process.
90				let notify = Arc::new(Notify::new());
91				let sh = which("sh").unwrap();
92				let child = std::process::Command::new(sh)
93					.args(vec!["-c", &command])
94					.env("HOST", &child_host.to_string())
95					.env("PORT", &child_port.to_string())
96					.spawn()
97					.unwrap();
98				*state.lock().await = State::Building {
99					notify: notify.clone(),
100					child: Some(child),
101				};
102				loop {
103					tokio::time::sleep(std::time::Duration::from_millis(100)).await;
104					if let State::Building { child, .. } = &mut *state.lock().await {
105						if let Ok(Some(_)) | Err(_) = child.as_mut().unwrap().try_wait() {
106							break;
107						}
108					}
109					if tokio::net::TcpStream::connect(&child_addr).await.is_ok() {
110						break;
111					}
112				}
113				let child = if let State::Building { child, .. } = &mut *state.lock().await {
114					child.take().unwrap()
115				} else {
116					panic!()
117				};
118				*state.lock().await = State::Running { child: Some(child) };
119				notify.notify_waiters();
120			}
121		}
122	});
123
124	// Handle requests by waiting for a build to finish if one is in progress, then proxying the request to the child process.
125	let handler = move |state: Arc<Mutex<State>>, mut request: http::Request<hyper::Body>| async move {
126		let notify = if let State::Building { notify, .. } = &mut *state.lock().await {
127			Some(notify.clone())
128		} else {
129			None
130		};
131		if let Some(notify) = notify {
132			notify.notified().await;
133		}
134		let child_authority = format!("{}:{}", child_host, child_port);
135		let child_authority = http::uri::Authority::from_maybe_shared(child_authority).unwrap();
136		*request.uri_mut() = http::Uri::builder()
137			.scheme("http")
138			.authority(child_authority)
139			.path_and_query(request.uri().path_and_query().unwrap().clone())
140			.build()
141			.unwrap();
142		hyper::Client::new()
143			.request(request)
144			.await
145			.unwrap_or_else(|_| {
146				http::Response::builder()
147					.status(http::StatusCode::SERVICE_UNAVAILABLE)
148					.body(hyper::Body::from("service unavailable"))
149					.unwrap()
150			})
151	};
152
153	// Start the server.
154	let service = hyper::service::make_service_fn(|_| {
155		let state = state.clone();
156		async move {
157			Ok::<_, Infallible>(hyper::service::service_fn(
158				move |request: http::Request<hyper::Body>| {
159					let state = state.clone();
160					async move { Ok::<_, Infallible>(handler(state, request).await) }
161				},
162			))
163		}
164	});
165	hyper::Server::bind(&addr).serve(service).await.unwrap();
166}