1use std::path::{Path, PathBuf};
4use std::sync::Arc;
5
6use rustc_hash::FxHashMap;
7use typst::diag::FileResult;
8use typst::foundations::Bytes;
9use typst::syntax::{FileId, Source, VirtualPath};
10
11use super::path::normalize_path;
12use crate::resource::file::{decode_utf8, file_id_from_path, read_with_global_virtual};
13
/// Error returned when a content file cannot be loaded while building a [`FileSnapshot`].
#[derive(Debug)]
pub struct SnapshotError {
    /// Path of the content file that failed to load.
    pub path: PathBuf,
    /// Underlying I/O-style error describing why the load failed.
    pub source: std::io::Error,
}

impl std::fmt::Display for SnapshotError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { path, source } = self;
        write!(f, "failed to read {}: {}", path.display(), source)
    }
}

impl std::error::Error for SnapshotError {
    /// Exposes the wrapped I/O error as the cause of this error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let cause: &(dyn std::error::Error + 'static) = &self.source;
        Some(cause)
    }
}
34
/// Options controlling how content files are loaded into a [`FileSnapshot`].
///
/// Both fields default to `None`, meaning content files are loaded verbatim.
#[derive(Debug, Default, Clone)]
pub struct SnapshotConfig {
    /// Text prepended, followed by a newline, to every content file.
    /// Files reached only through `import`/`include` are not modified.
    pub prelude: Option<String>,
    /// Text appended, preceded by a newline, to every content file.
    /// Files reached only through `import`/`include` are not modified.
    pub postlude: Option<String>,
}
43
44#[derive(Clone)]
48pub struct FileSnapshot {
49 sources: Arc<FxHashMap<FileId, Source>>,
50 files: Arc<FxHashMap<FileId, Bytes>>,
51}
52
53impl FileSnapshot {
54 pub fn build(content_files: &[PathBuf], root: &Path) -> Result<Self, SnapshotError> {
58 Self::build_with_config(content_files, root, &SnapshotConfig::default(), |_| {})
59 }
60
61 pub fn build_each(
65 content_files: &[PathBuf],
66 root: &Path,
67 on_load: impl Fn(&Path) + Sync,
68 ) -> Result<Self, SnapshotError> {
69 Self::build_with_config(content_files, root, &SnapshotConfig::default(), on_load)
70 }
71
72 pub fn build_with_config(
78 content_files: &[PathBuf],
79 root: &Path,
80 config: &SnapshotConfig,
81 on_load: impl Fn(&Path) + Sync,
82 ) -> Result<Self, SnapshotError> {
83 let root = normalize_path(root);
84
85 let main_ids: rustc_hash::FxHashSet<FileId> = content_files
87 .iter()
88 .filter_map(|p| file_id_from_path(p, &root))
89 .collect();
90
91 let sources = load_sources_with_imports(content_files, &root, config, &main_ids, on_load)?;
92
93 Ok(Self {
94 sources: Arc::new(sources),
95 files: Arc::new(FxHashMap::default()),
96 })
97 }
98
99 #[inline]
101 pub fn get_source(&self, id: FileId) -> Option<Source> {
102 self.sources.get(&id).cloned()
103 }
104
105 #[inline]
107 pub fn get_file(&self, id: FileId) -> Option<Bytes> {
108 self.files.get(&id).cloned()
109 }
110
111 #[inline]
113 pub fn source_count(&self) -> usize {
114 self.sources.len()
115 }
116}
117
118fn load_sources_with_imports(
123 content_files: &[PathBuf],
124 root: &Path,
125 config: &SnapshotConfig,
126 main_ids: &rustc_hash::FxHashSet<FileId>,
127 on_load: impl Fn(&Path) + Sync,
128) -> Result<FxHashMap<FileId, Source>, SnapshotError> {
129 use rayon::prelude::*;
130 use std::sync::Mutex;
131
132 let sources = Mutex::new(FxHashMap::default());
133 let first_error: Mutex<Option<SnapshotError>> = Mutex::new(None);
134
135 let initial: Vec<_> = content_files
137 .par_iter()
138 .filter_map(|path| {
139 if first_error.lock().unwrap().is_some() {
141 return None;
142 }
143
144 let id = match file_id_from_path(path, root) {
145 Some(id) => id,
146 None => return None, };
148
149 match load_source_with_injection(id, root, config, main_ids) {
150 Ok(source) => {
151 on_load(path);
152 Some((id, source))
153 }
154 Err(_) => {
155 let mut err = first_error.lock().unwrap();
157 if err.is_none() {
158 *err = Some(SnapshotError {
159 path: path.clone(),
160 source: std::io::Error::new(
161 std::io::ErrorKind::NotFound,
162 format!("failed to load source: {}", path.display()),
163 ),
164 });
165 }
166 None
167 }
168 }
169 })
170 .collect();
171
172 if let Some(err) = first_error.into_inner().unwrap() {
174 return Err(err);
175 }
176
177 let mut pending: Vec<FileId> = Vec::new();
179 for (id, source) in initial {
180 pending.extend(parse_imports(&source));
181 sources.lock().unwrap().insert(id, source);
182 }
183
184 while !pending.is_empty() {
186 let batch: Vec<_> = {
187 let sources = sources.lock().unwrap();
188 pending
189 .drain(..)
190 .filter(|id| !sources.contains_key(id))
191 .collect()
192 };
193
194 if batch.is_empty() {
195 break;
196 }
197
198 let results: Vec<_> = batch
201 .par_iter()
202 .filter_map(|&id| load_source(id, root).ok().map(|s| (id, s)))
203 .collect();
204
205 let mut sources = sources.lock().unwrap();
206 for (id, source) in results {
207 if sources.contains_key(&id) {
208 continue;
209 }
210 for import_id in parse_imports(&source) {
211 if !sources.contains_key(&import_id) {
212 pending.push(import_id);
213 }
214 }
215 sources.insert(id, source);
216 }
217 }
218
219 Ok(sources.into_inner().unwrap())
220}
221
222fn load_source_with_injection(
224 id: FileId,
225 root: &Path,
226 config: &SnapshotConfig,
227 main_ids: &rustc_hash::FxHashSet<FileId>,
228) -> FileResult<Source> {
229 let bytes = read_with_global_virtual(id, root)?;
230 let text = decode_utf8(&bytes)?;
231
232 let text = if main_ids.contains(&id) {
234 let mut result = String::new();
235 if let Some(prelude) = &config.prelude {
236 result.push_str(prelude);
237 result.push('\n');
238 }
239 result.push_str(text);
240 if let Some(postlude) = &config.postlude {
241 result.push('\n');
242 result.push_str(postlude);
243 }
244 result
245 } else {
246 text.into()
247 };
248
249 Ok(Source::new(id, text))
250}
251
252fn load_source(id: FileId, root: &Path) -> FileResult<Source> {
253 let bytes = read_with_global_virtual(id, root)?;
254 let text = decode_utf8(&bytes)?;
255 Ok(Source::new(id, text.into()))
256}
257
258fn parse_imports(source: &Source) -> Vec<FileId> {
263 use typst::syntax::{ast, SyntaxKind};
264
265 let mut imports = Vec::new();
266 let mut stack = vec![source.root().clone()];
267 let current = source.id();
268
269 while let Some(node) = stack.pop() {
270 match node.kind() {
271 SyntaxKind::ModuleImport => {
272 if let Some(import) = node.cast::<ast::ModuleImport>()
273 && let Some(id) = resolve_import_path(&import.source(), current) {
274 imports.push(id);
275 }
276 }
277 SyntaxKind::ModuleInclude => {
278 if let Some(include) = node.cast::<ast::ModuleInclude>()
279 && let Some(id) = resolve_import_path(&include.source(), current) {
280 imports.push(id);
281 }
282 }
283 _ => stack.extend(node.children().cloned()),
284 }
285 }
286
287 imports
288}
289
290fn resolve_import_path(expr: &typst::syntax::ast::Expr, current: FileId) -> Option<FileId> {
291 use typst::syntax::ast;
292
293 let path_str = match expr {
294 ast::Expr::Str(s) => s.get(),
295 _ => return None,
296 };
297
298 if path_str.starts_with('@') {
300 return None;
301 }
302
303 let resolved = if path_str.starts_with('/') {
304 VirtualPath::new(&*path_str)
305 } else {
306 current.vpath().join(&*path_str)
307 };
308
309 Some(FileId::new(None, resolved))
310}