// typst_batch/world/snapshot.rs

//! Immutable file snapshot for lock-free parallel compilation.

use std::path::{Path, PathBuf};
use std::sync::Arc;

use rustc_hash::FxHashMap;
use typst::diag::FileResult;
use typst::foundations::Bytes;
use typst::syntax::{FileId, Source, VirtualPath};

use super::path::normalize_path;
use crate::resource::file::{decode_utf8, file_id_from_path, read_with_global_virtual};

/// Error produced when building a file snapshot.
///
/// Pairs the path of the file that could not be loaded with the I/O error
/// describing the failure.
#[derive(Debug)]
pub struct SnapshotError {
    /// The file path that failed to load.
    pub path: PathBuf,
    /// The underlying IO error.
    pub source: std::io::Error,
}

impl std::fmt::Display for SnapshotError {
    /// Formats the error as `failed to read <path>: <io error>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "failed to read {}: {}", self.path.display(), self.source)
    }
}

impl std::error::Error for SnapshotError {
    /// Exposes the wrapped I/O error as the cause of this error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        Some(&self.source)
    }
}

/// Configuration for building a snapshot with prelude/postlude injection.
///
/// Both fields default to `None`, in which case main files are stored
/// without any injected text.
#[derive(Debug, Default, Clone)]
pub struct SnapshotConfig {
    /// Code to inject at the beginning of each main file.
    pub prelude: Option<String>,
    /// Code to inject at the end of each main file.
    pub postlude: Option<String>,
}

44/// Immutable file content snapshot for lock-free parallel access.
45///
46/// Built once before parallel compilation, then shared across all threads.
47#[derive(Clone)]
48pub struct FileSnapshot {
49    sources: Arc<FxHashMap<FileId, Source>>,
50    files: Arc<FxHashMap<FileId, Bytes>>,
51}
52
53impl FileSnapshot {
54    /// Build a snapshot by pre-scanning all content files and their imports.
55    ///
56    /// Returns an error if any content file fails to load.
57    pub fn build(content_files: &[PathBuf], root: &Path) -> Result<Self, SnapshotError> {
58        Self::build_with_config(content_files, root, &SnapshotConfig::default(), |_| {})
59    }
60
61    /// Build a snapshot with callback for each file loaded.
62    ///
63    /// Returns an error if any content file fails to load.
64    pub fn build_each(
65        content_files: &[PathBuf],
66        root: &Path,
67        on_load: impl Fn(&Path) + Sync,
68    ) -> Result<Self, SnapshotError> {
69        Self::build_with_config(content_files, root, &SnapshotConfig::default(), on_load)
70    }
71
72    /// Build a snapshot with prelude/postlude injection.
73    ///
74    /// The prelude is injected at the beginning of each main file, and its imports
75    /// are also included in the snapshot. This ensures all dependencies are available
76    /// during compilation.
77    pub fn build_with_config(
78        content_files: &[PathBuf],
79        root: &Path,
80        config: &SnapshotConfig,
81        on_load: impl Fn(&Path) + Sync,
82    ) -> Result<Self, SnapshotError> {
83        let root = normalize_path(root);
84
85        // Collect main file IDs for prelude injection
86        let main_ids: rustc_hash::FxHashSet<FileId> = content_files
87            .iter()
88            .filter_map(|p| file_id_from_path(p, &root))
89            .collect();
90
91        let sources = load_sources_with_imports(content_files, &root, config, &main_ids, on_load)?;
92
93        Ok(Self {
94            sources: Arc::new(sources),
95            files: Arc::new(FxHashMap::default()),
96        })
97    }
98
99    /// Gets a cached source by file ID.
100    #[inline]
101    pub fn get_source(&self, id: FileId) -> Option<Source> {
102        self.sources.get(&id).cloned()
103    }
104
105    /// Gets cached file bytes by file ID.
106    #[inline]
107    pub fn get_file(&self, id: FileId) -> Option<Bytes> {
108        self.files.get(&id).cloned()
109    }
110
111    /// Returns the number of cached sources.
112    #[inline]
113    pub fn source_count(&self) -> usize {
114        self.sources.len()
115    }
116}
117
// ============================================================================
// Source Loading
// ============================================================================

122fn load_sources_with_imports(
123    content_files: &[PathBuf],
124    root: &Path,
125    config: &SnapshotConfig,
126    main_ids: &rustc_hash::FxHashSet<FileId>,
127    on_load: impl Fn(&Path) + Sync,
128) -> Result<FxHashMap<FileId, Source>, SnapshotError> {
129    use rayon::prelude::*;
130    use std::sync::Mutex;
131
132    let sources = Mutex::new(FxHashMap::default());
133    let first_error: Mutex<Option<SnapshotError>> = Mutex::new(None);
134
135    // Load initial files in parallel (with prelude/postlude injection for main files)
136    let initial: Vec<_> = content_files
137        .par_iter()
138        .filter_map(|path| {
139            // Skip if we already have an error
140            if first_error.lock().unwrap().is_some() {
141                return None;
142            }
143
144            let id = match file_id_from_path(path, root) {
145                Some(id) => id,
146                None => return None, // Path outside root, skip
147            };
148
149            match load_source_with_injection(id, root, config, main_ids) {
150                Ok(source) => {
151                    on_load(path);
152                    Some((id, source))
153                }
154                Err(_) => {
155                    // Record the first error
156                    let mut err = first_error.lock().unwrap();
157                    if err.is_none() {
158                        *err = Some(SnapshotError {
159                            path: path.clone(),
160                            source: std::io::Error::new(
161                                std::io::ErrorKind::NotFound,
162                                format!("failed to load source: {}", path.display()),
163                            ),
164                        });
165                    }
166                    None
167                }
168            }
169        })
170        .collect();
171
172    // Check for errors
173    if let Some(err) = first_error.into_inner().unwrap() {
174        return Err(err);
175    }
176
177    // Collect imports from initial files (prelude imports are included since prelude was injected)
178    let mut pending: Vec<FileId> = Vec::new();
179    for (id, source) in initial {
180        pending.extend(parse_imports(&source));
181        sources.lock().unwrap().insert(id, source);
182    }
183
184    // BFS to load all imports (imports are optional, skip failures)
185    while !pending.is_empty() {
186        let batch: Vec<_> = {
187            let sources = sources.lock().unwrap();
188            pending
189                .drain(..)
190                .filter(|id| !sources.contains_key(id))
191                .collect()
192        };
193
194        if batch.is_empty() {
195            break;
196        }
197
198        // For imports, we silently skip failures (they might be package imports
199        // or optional files that will be handled at compile time)
200        let results: Vec<_> = batch
201            .par_iter()
202            .filter_map(|&id| load_source(id, root).ok().map(|s| (id, s)))
203            .collect();
204
205        let mut sources = sources.lock().unwrap();
206        for (id, source) in results {
207            if sources.contains_key(&id) {
208                continue;
209            }
210            for import_id in parse_imports(&source) {
211                if !sources.contains_key(&import_id) {
212                    pending.push(import_id);
213                }
214            }
215            sources.insert(id, source);
216        }
217    }
218
219    Ok(sources.into_inner().unwrap())
220}
221
222/// Load source with prelude/postlude injection for main files.
223fn load_source_with_injection(
224    id: FileId,
225    root: &Path,
226    config: &SnapshotConfig,
227    main_ids: &rustc_hash::FxHashSet<FileId>,
228) -> FileResult<Source> {
229    let bytes = read_with_global_virtual(id, root)?;
230    let text = decode_utf8(&bytes)?;
231
232    // Inject prelude/postlude for main files
233    let text = if main_ids.contains(&id) {
234        let mut result = String::new();
235        if let Some(prelude) = &config.prelude {
236            result.push_str(prelude);
237            result.push('\n');
238        }
239        result.push_str(text);
240        if let Some(postlude) = &config.postlude {
241            result.push('\n');
242            result.push_str(postlude);
243        }
244        result
245    } else {
246        text.into()
247    };
248
249    Ok(Source::new(id, text))
250}
251
252fn load_source(id: FileId, root: &Path) -> FileResult<Source> {
253    let bytes = read_with_global_virtual(id, root)?;
254    let text = decode_utf8(&bytes)?;
255    Ok(Source::new(id, text.into()))
256}
257
// ============================================================================
// Import Parsing
// ============================================================================

262fn parse_imports(source: &Source) -> Vec<FileId> {
263    use typst::syntax::{ast, SyntaxKind};
264
265    let mut imports = Vec::new();
266    let mut stack = vec![source.root().clone()];
267    let current = source.id();
268
269    while let Some(node) = stack.pop() {
270        match node.kind() {
271            SyntaxKind::ModuleImport => {
272                if let Some(import) = node.cast::<ast::ModuleImport>()
273                    && let Some(id) = resolve_import_path(&import.source(), current) {
274                        imports.push(id);
275                    }
276            }
277            SyntaxKind::ModuleInclude => {
278                if let Some(include) = node.cast::<ast::ModuleInclude>()
279                    && let Some(id) = resolve_import_path(&include.source(), current) {
280                        imports.push(id);
281                    }
282            }
283            _ => stack.extend(node.children().cloned()),
284        }
285    }
286
287    imports
288}
289
290fn resolve_import_path(expr: &typst::syntax::ast::Expr, current: FileId) -> Option<FileId> {
291    use typst::syntax::ast;
292
293    let path_str = match expr {
294        ast::Expr::Str(s) => s.get(),
295        _ => return None,
296    };
297
298    // Skip package imports
299    if path_str.starts_with('@') {
300        return None;
301    }
302
303    let resolved = if path_str.starts_with('/') {
304        VirtualPath::new(&*path_str)
305    } else {
306        current.vpath().join(&*path_str)
307    };
308
309    Some(FileId::new(None, resolved))
310}