web_dev_server/
startup.rs

1use std::{
2    io::ErrorKind,
3    net::TcpListener,
4    path::{Component, Path, PathBuf},
5};
6
7use actix_files::NamedFile;
8use actix_web::{
9    HttpRequest, HttpResponse, Result as ActixResult,
10    dev::Server,
11    error::{ErrorInternalServerError, ErrorNotFound},
12    web,
13};
14use anyhow::{Context, anyhow};
15use notify::{
16    RecommendedWatcher, RecursiveMode, Watcher,
17    event::{EventKind, ModifyKind, RenameMode},
18    recommended_watcher,
19};
20use tokio::fs;
21use tokio::sync::{broadcast, mpsc};
22use tokio::time::{Duration, sleep};
23
24use crate::{
25    config::{self, DevServerConfig},
26    internal_scope::build_internal_scope,
27};
28
29#[derive(Clone)]
30pub struct AppState {
31    pub base_dir: PathBuf,
32    pub broadcaster: broadcast::Sender<LiveMessage>,
33    pub diff_mode: bool,
34}
35
36#[derive(Clone, Debug, serde::Serialize)]
37#[serde(tag = "type", rename_all = "kebab-case")]
38pub enum LiveMessage {
39    Reload,
40    Diff {
41        path: String,
42        resource: DiffResource,
43    },
44}
45
46#[derive(Clone, Debug, serde::Serialize)]
47#[serde(rename_all = "lowercase")]
48pub enum DiffResource {
49    Html,
50    Css,
51}
52
53pub struct Application {
54    server: Server,
55    port: u16,
56    _watcher: RecommendedWatcher,
57    state: AppState,
58}
59
60impl Application {
61    pub async fn build(config: &DevServerConfig) -> anyhow::Result<Self> {
62        let allow_fallback = config.port == config::DEFAULT_PORT;
63        let (listener, port) = bind_listener(config.port, allow_fallback)?;
64
65        if allow_fallback && port != config.port {
66            println!(
67                "[web-dev-server] port {} in use, switched to {}",
68                config.port, port
69            );
70        }
71
72        let base_dir = resolve_base_dir(&config.base_dir)
73            .with_context(|| format!("failed to resolve base directory {}", config.base_dir))?;
74
75        let (broadcaster, _) = broadcast::channel(64);
76
77        let state = AppState {
78            base_dir: base_dir.clone(),
79            broadcaster: broadcaster.clone(),
80            diff_mode: config.diff_mode,
81        };
82
83        let (watcher, notify_rx) = create_watcher(&state)?;
84        spawn_watcher_loop(state.clone(), notify_rx);
85
86        let server = run(listener, state.clone()).await?;
87
88        Ok(Self {
89            server,
90            port,
91            _watcher: watcher,
92            state,
93        })
94    }
95
96    pub fn port(&self) -> u16 {
97        self.port
98    }
99
100    pub fn base_dir(&self) -> &Path {
101        &self.state.base_dir
102    }
103
104    pub fn diff_mode(&self) -> bool {
105        self.state.diff_mode
106    }
107
108    pub fn primary_url(&self) -> String {
109        format!("http://127.0.0.1:{}", self.port)
110    }
111
112    pub async fn run_until_stopped(self) -> std::io::Result<()> {
113        self.server.await
114    }
115}
116
117fn bind_listener(preferred_port: u16, allow_fallback: bool) -> anyhow::Result<(TcpListener, u16)> {
118    let mut port = preferred_port;
119
120    loop {
121        match TcpListener::bind(("127.0.0.1", port)) {
122            Ok(listener) => return Ok((listener, port)),
123            Err(error) if allow_fallback && error.kind() == ErrorKind::AddrInUse => {
124                if port == u16::MAX {
125                    return Err(anyhow!(
126                        "failed to find an available port starting at {}",
127                        preferred_port
128                    ));
129                }
130                port = port.checked_add(1).ok_or_else(|| {
131                    anyhow!(
132                        "failed to find an available port starting at {}",
133                        preferred_port
134                    )
135                })?;
136            }
137            Err(error) => {
138                return Err(anyhow::Error::from(error)
139                    .context(format!("failed to bind to 127.0.0.1:{port}")));
140            }
141        }
142    }
143}
144
145async fn run(listener: TcpListener, state: AppState) -> anyhow::Result<Server> {
146    let shared_state = web::Data::new(state);
147
148    let server = actix_web::HttpServer::new(move || {
149        actix_web::App::new()
150            .app_data(shared_state.clone())
151            .service(build_internal_scope())
152            .service(web::resource("/{tail:.*}").route(web::to(serve_file)))
153    })
154    .listen(listener)?
155    .run();
156
157    Ok(server)
158}
159
160fn resolve_base_dir(base_dir: &str) -> anyhow::Result<PathBuf> {
161    let path = PathBuf::from(base_dir);
162    let absolute = if path.is_absolute() {
163        path
164    } else {
165        std::env::current_dir()?.join(path)
166    };
167
168    let canonical = absolute.canonicalize()?;
169    if canonical.is_dir() {
170        Ok(canonical)
171    } else {
172        anyhow::bail!("base directory must be a directory")
173    }
174}
175
176fn create_watcher(
177    state: &AppState,
178) -> anyhow::Result<(
179    RecommendedWatcher,
180    mpsc::UnboundedReceiver<notify::Result<notify::Event>>,
181)> {
182    let (tx, rx) = mpsc::unbounded_channel();
183    let root = state.base_dir.clone();
184
185    let mut watcher = recommended_watcher(move |res| {
186        let _ = tx.send(res);
187    })?;
188
189    watcher.watch(&root, RecursiveMode::Recursive)?;
190
191    Ok((watcher, rx))
192}
193
194fn spawn_watcher_loop(
195    state: AppState,
196    mut rx: mpsc::UnboundedReceiver<notify::Result<notify::Event>>,
197) {
198    tokio::spawn(async move {
199        while let Some(event) = rx.recv().await {
200            match event {
201                Ok(event) => {
202                    let state_for_event = state.clone();
203                    tokio::spawn(async move {
204                        sleep(Duration::from_millis(120)).await;
205                        handle_fs_event(state_for_event, event);
206                    });
207                }
208                Err(error) => {
209                    eprintln!("[web-dev-server] watcher error: {error}");
210                    let _ = state.broadcaster.send(LiveMessage::Reload);
211                }
212            }
213        }
214    });
215}
216
217fn handle_fs_event(state: AppState, event: notify::Event) {
218    if !state.diff_mode {
219        let _ = state.broadcaster.send(LiveMessage::Reload);
220        return;
221    }
222
223    if event.need_rescan() {
224        let _ = state.broadcaster.send(LiveMessage::Reload);
225        return;
226    }
227
228    let kind = event.kind;
229
230    if should_ignore_event(&kind) {
231        return;
232    }
233
234    let mut diff_messages = Vec::new();
235
236    for path in event.paths {
237        if let Some(message) = classify_path(&state, &path) {
238            diff_messages.push(message);
239        }
240    }
241
242    if !diff_messages.is_empty() {
243        if allows_diff(&kind) {
244            for message in diff_messages {
245                let _ = state.broadcaster.send(message);
246            }
247            return;
248        } else {
249            let _ = state.broadcaster.send(LiveMessage::Reload);
250            return;
251        }
252    }
253
254    if should_reload_when_no_diff(&kind) {
255        let _ = state.broadcaster.send(LiveMessage::Reload);
256    }
257}
258
259fn should_ignore_event(kind: &EventKind) -> bool {
260    matches!(
261        kind,
262        EventKind::Access(_) | EventKind::Modify(ModifyKind::Name(RenameMode::From))
263    )
264}
265
266fn allows_diff(kind: &EventKind) -> bool {
267    matches!(
268        kind,
269        EventKind::Create(_)
270            | EventKind::Modify(ModifyKind::Data(_))
271            | EventKind::Modify(ModifyKind::Metadata(_))
272            | EventKind::Modify(ModifyKind::Any)
273            | EventKind::Modify(ModifyKind::Name(
274                RenameMode::To | RenameMode::Both | RenameMode::Any | RenameMode::Other
275            ))
276    )
277}
278
279fn should_reload_when_no_diff(kind: &EventKind) -> bool {
280    match kind {
281        EventKind::Remove(_) | EventKind::Other | EventKind::Any => true,
282        EventKind::Modify(ModifyKind::Name(mode)) => !matches!(mode, RenameMode::From),
283        EventKind::Modify(ModifyKind::Other) => true,
284        _ => false,
285    }
286}
287
288fn classify_path(state: &AppState, path: &Path) -> Option<LiveMessage> {
289    let normalized = normalize_event_path(&state.base_dir, path)?;
290    let ext = normalized.extension()?.to_str()?.to_ascii_lowercase();
291    let resource = match ext.as_str() {
292        "html" | "htm" => DiffResource::Html,
293        "css" => DiffResource::Css,
294        _ => return None,
295    };
296
297    let web_path = to_web_path(&state.base_dir, &normalized, &resource)?;
298
299    Some(LiveMessage::Diff {
300        path: web_path,
301        resource,
302    })
303}
304
/// Produces the absolute path used to classify a watcher event.
///
/// Relative paths are resolved against `base_dir`. The previous version canonicalized them
/// against the process working directory first, which could pick an unrelated file
/// (e.g. `.` became the CWD).
///
/// The result is canonicalized when the file still exists; for deleted or renamed-away files
/// the joined, non-canonical path is returned. Always returns `Some`; the `Option` is kept
/// for the callers' `?` chains.
fn normalize_event_path(base_dir: &Path, path: &Path) -> Option<PathBuf> {
    let resolved = if path.is_absolute() {
        path.to_path_buf()
    } else {
        base_dir.join(path)
    };

    Some(std::fs::canonicalize(&resolved).unwrap_or(resolved))
}
322
323fn to_web_path(base_dir: &Path, path: &Path, resource: &DiffResource) -> Option<String> {
324    let relative = path.strip_prefix(base_dir).ok()?;
325    let mut rel_str = relative.to_string_lossy().replace('\\', "/");
326    if rel_str.is_empty() {
327        return Some(String::from("/"));
328    }
329
330    rel_str = rel_str.trim_start_matches('/').to_owned();
331
332    match resource {
333        DiffResource::Html => {
334            if rel_str.ends_with("index.html") {
335                let prefix = rel_str.trim_end_matches("index.html");
336                if prefix.is_empty() {
337                    Some(String::from("/"))
338                } else {
339                    let trimmed = prefix.trim_end_matches('/');
340                    let mut path = format!("/{}", trimmed);
341                    if !path.ends_with('/') {
342                        path.push('/');
343                    }
344                    Some(path)
345                }
346            } else if rel_str.ends_with("index.htm") {
347                let prefix = rel_str.trim_end_matches("index.htm");
348                if prefix.is_empty() {
349                    Some(String::from("/"))
350                } else {
351                    let trimmed = prefix.trim_end_matches('/');
352                    let mut path = format!("/{}", trimmed);
353                    if !path.ends_with('/') {
354                        path.push('/');
355                    }
356                    Some(path)
357                }
358            } else {
359                Some(format!("/{}", rel_str))
360            }
361        }
362        DiffResource::Css => Some(format!("/{}", rel_str)),
363    }
364}
365
366async fn serve_file(
367    req: HttpRequest,
368    tail: web::Path<String>,
369    state: web::Data<AppState>,
370) -> ActixResult<HttpResponse> {
371    let target = locate_file(&state.base_dir, tail.as_str())
372        .await
373        .map_err(|_| ErrorNotFound("Not Found"))?;
374
375    if is_html(&target) {
376        let raw = fs::read_to_string(&target)
377            .await
378            .map_err(ErrorInternalServerError)?;
379        let injected =
380            inject_live_client(&raw, state.diff_mode).map_err(ErrorInternalServerError)?;
381
382        Ok(HttpResponse::Ok()
383            .append_header(("Cache-Control", "no-cache, no-store, must-revalidate"))
384            .content_type("text/html; charset=utf-8")
385            .body(injected))
386    } else {
387        let file = NamedFile::open_async(&target)
388            .await
389            .map_err(|_| ErrorNotFound("Not Found"))?;
390
391        Ok(file.into_response(&req))
392    }
393}
394
395async fn locate_file(base_dir: &Path, tail: &str) -> anyhow::Result<PathBuf> {
396    let mut full_path = sanitize_path(base_dir, tail)?;
397
398    if let Ok(metadata) = fs::metadata(&full_path).await {
399        if metadata.is_dir() {
400            let index_html = full_path.join("index.html");
401            if fs::metadata(&index_html).await.is_ok() {
402                full_path = index_html;
403            } else {
404                anyhow::bail!("directory has no index.html");
405            }
406        }
407        Ok(full_path)
408    } else {
409        anyhow::bail!("file not found")
410    }
411}
412
413fn sanitize_path(base_dir: &Path, tail: &str) -> anyhow::Result<PathBuf> {
414    let trimmed = tail.trim_start_matches('/');
415    let mut target = PathBuf::from(base_dir);
416
417    if trimmed.is_empty() {
418        target.push("index.html");
419        return Ok(target);
420    }
421
422    let mut has_component = false;
423
424    for component in Path::new(trimmed).components() {
425        match component {
426            Component::Normal(part) => {
427                target.push(part);
428                has_component = true;
429            }
430            Component::CurDir => {}
431            _ => anyhow::bail!("invalid path"),
432        }
433    }
434
435    if !has_component && tail.ends_with('/') {
436        target.push("index.html");
437    }
438
439    Ok(target)
440}
441
/// Returns `true` when the file extension is `html` or `htm`, ignoring ASCII case.
///
/// Paths without a UTF-8 extension are not HTML.
fn is_html(path: &Path) -> bool {
    let Some(ext) = path.extension().and_then(|ext| ext.to_str()) else {
        return false;
    };
    ext.eq_ignore_ascii_case("html") || ext.eq_ignore_ascii_case("htm")
}
448
449fn inject_live_client(original: &str, diff_mode: bool) -> anyhow::Result<String> {
450    if original.contains("__web_dev_server_client") {
451        return Ok(original.to_string());
452    }
453
454    let config = serde_json::json!({
455        "wsPath": "/_live/ws",
456        "diffMode": diff_mode,
457    });
458
459    let snippet = format!(
460        r#"<script id="__web_dev_server_config">window.__WEB_DEV_SERVER_CONFIG__ = {};</script><script id="__web_dev_server_client" defer src="/_live/script.js"></script>"#,
461        serde_json::to_string(&config)?
462    );
463
464    if let Some(idx) = original.rfind("</head>") {
465        let mut result = String::with_capacity(original.len() + snippet.len() + 2);
466        result.push_str(&original[..idx]);
467        result.push('\n');
468        result.push_str(&snippet);
469        result.push('\n');
470        result.push_str(&original[idx..]);
471        Ok(result)
472    } else {
473        let mut result = original.to_string();
474        if !result.ends_with('\n') {
475            result.push('\n');
476        }
477        result.push_str(&snippet);
478        Ok(result)
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use notify::event::{AccessKind, DataChange, MetadataKind, ModifyKind, RemoveKind, RenameMode};

    #[test]
    fn diff_message_serializes_resource_lowercase() {
        let message = LiveMessage::Diff {
            path: "/".into(),
            resource: DiffResource::Html,
        };
        let json = serde_json::to_string(&message).unwrap();
        assert!(
            json.contains(r#""resource":"html""#),
            "serialized json was {json}"
        );
    }

    #[test]
    fn access_events_are_ignored_for_diff_mode() {
        assert!(should_ignore_event(&EventKind::Access(AccessKind::Read)));
    }

    #[test]
    fn rename_from_events_are_ignored() {
        let event = EventKind::Modify(ModifyKind::Name(RenameMode::From));
        assert!(should_ignore_event(&event));
    }

    #[test]
    fn modify_data_events_allow_diff() {
        let event = EventKind::Modify(ModifyKind::Data(DataChange::Any));
        assert!(allows_diff(&event));
    }

    #[test]
    fn metadata_events_allow_diff() {
        let event = EventKind::Modify(ModifyKind::Metadata(MetadataKind::WriteTime));
        assert!(allows_diff(&event));
    }

    #[test]
    fn remove_events_force_reload_when_no_diff() {
        let event = EventKind::Remove(RemoveKind::File);
        assert!(should_reload_when_no_diff(&event));
    }

    #[test]
    fn relative_paths_are_classified_within_base_dir() {
        let base_dir =
            std::env::temp_dir().join(format!("web_dev_server_test_{}", std::process::id()));
        std::fs::create_dir_all(&base_dir).unwrap();
        let canonical = std::fs::canonicalize(&base_dir).unwrap();
        let (tx, _) = broadcast::channel(1);
        let state = AppState {
            base_dir: canonical,
            broadcaster: tx,
            diff_mode: true,
        };

        let message = classify_path(&state, Path::new("index.html"));
        // Remove the scratch directory before asserting so a failure does not leak it.
        let _ = std::fs::remove_dir_all(&base_dir);

        match message.expect("expected diff message for html file") {
            LiveMessage::Diff { path, resource } => {
                assert_eq!(path, "/");
                assert!(matches!(resource, DiffResource::Html));
            }
            LiveMessage::Reload => panic!("expected diff message"),
        }
    }

    #[test]
    fn nested_index_pages_map_to_directory_urls() {
        let base = Path::new("/srv/site");
        assert_eq!(
            to_web_path(base, &base.join("docs/index.html"), &DiffResource::Html).as_deref(),
            Some("/docs/")
        );
        assert_eq!(
            to_web_path(base, &base.join("css/app.css"), &DiffResource::Css).as_deref(),
            Some("/css/app.css")
        );
        assert_eq!(
            to_web_path(base, Path::new("/elsewhere/index.html"), &DiffResource::Html),
            None
        );
    }

    #[test]
    fn sanitize_path_rejects_parent_directory_segments() {
        let base = Path::new("/srv/site");
        assert!(sanitize_path(base, "../etc/passwd").is_err());
        assert!(sanitize_path(base, "docs/../../secret").is_err());
        assert_eq!(sanitize_path(base, "/").unwrap(), base.join("index.html"));
        assert_eq!(
            sanitize_path(base, "css/app.css").unwrap(),
            base.join("css").join("app.css")
        );
    }
}