Skip to main content

yulang_monomorphize/
cache.rs

1use std::{
2    collections::{HashMap, hash_map::DefaultHasher},
3    fs,
4    hash::{Hash, Hasher},
5    io,
6    path::{Path, PathBuf},
7};
8
9use serde::{Deserialize, Serialize};
10use yulang_runtime_ir::{
11    FinalizedBinding as Binding, FinalizedExpr as Expr, FinalizedExprKind as ExprKind, RuntimeType,
12};
13use yulang_sources::{CompiledUnitManifest, YulangCachePaths};
14use yulang_typed_ir as typed_ir;
15
/// Format version of the serialized monomorphize instance cache.
///
/// The value is stamped into every surface produced by this module and into the
/// artifact cache key; surfaces read from disk with a different version are rejected.
pub const MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION: u32 = 2;
17
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct MonomorphizeInstanceArtifactCache {
20    root: PathBuf,
21}
22
23impl MonomorphizeInstanceArtifactCache {
24    pub fn new(root: impl Into<PathBuf>) -> Self {
25        Self { root: root.into() }
26    }
27
28    pub fn from_paths(paths: &YulangCachePaths) -> Self {
29        Self::new(paths.compiled_units.clone())
30    }
31
32    pub fn read_for_manifests(
33        &self,
34        manifests: &[CompiledUnitManifest],
35    ) -> Result<MonomorphizeInstanceCacheSurface, MonomorphizeInstanceArtifactCacheError> {
36        let key = MonomorphizeInstanceArtifactCacheKey::from_manifests(manifests)?;
37        let path = self.artifact_path(&key);
38        let bytes =
39            fs::read(&path).map_err(|error| MonomorphizeInstanceArtifactCacheError::Io {
40                path: path.clone(),
41                error: io_error_string(error),
42            })?;
43        let surface =
44            postcard::from_bytes::<MonomorphizeInstanceCacheSurface>(&bytes).map_err(|error| {
45                MonomorphizeInstanceArtifactCacheError::Deserialize {
46                    path: path.clone(),
47                    error: error.to_string(),
48                }
49            })?;
50        if surface.format_version != MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION {
51            return Err(
52                MonomorphizeInstanceArtifactCacheError::UnsupportedFinalizeFormat {
53                    format_version: surface.format_version,
54                },
55            );
56        }
57        Ok(surface)
58    }
59
60    pub fn read_cache_for_manifests(
61        &self,
62        manifests: &[CompiledUnitManifest],
63    ) -> MonomorphizeInstanceCache {
64        self.read_for_manifests(manifests)
65            .map(MonomorphizeInstanceCache::from_surface)
66            .unwrap_or_default()
67    }
68
69    pub fn write_cache_for_manifests(
70        &self,
71        manifests: &[CompiledUnitManifest],
72        cache: &MonomorphizeInstanceCache,
73    ) -> Result<PathBuf, MonomorphizeInstanceArtifactCacheError> {
74        self.write_for_manifests(manifests, &cache.to_surface())
75    }
76
77    pub fn write_for_manifests(
78        &self,
79        manifests: &[CompiledUnitManifest],
80        surface: &MonomorphizeInstanceCacheSurface,
81    ) -> Result<PathBuf, MonomorphizeInstanceArtifactCacheError> {
82        let key = MonomorphizeInstanceArtifactCacheKey::from_manifests(manifests)?;
83        let path = self.artifact_path(&key);
84        if let Some(parent) = path.parent() {
85            fs::create_dir_all(parent).map_err(|error| {
86                MonomorphizeInstanceArtifactCacheError::Io {
87                    path: parent.to_path_buf(),
88                    error: io_error_string(error),
89                }
90            })?;
91        }
92        let bytes = postcard::to_allocvec(surface).map_err(|error| {
93            MonomorphizeInstanceArtifactCacheError::Serialize {
94                error: error.to_string(),
95            }
96        })?;
97        fs::write(&path, bytes).map_err(|error| MonomorphizeInstanceArtifactCacheError::Io {
98            path: path.clone(),
99            error: io_error_string(error),
100        })?;
101        Ok(path)
102    }
103
104    fn artifact_path(&self, key: &MonomorphizeInstanceArtifactCacheKey) -> PathBuf {
105        key.directory(&self.root).join(key.file_name())
106    }
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Hash)]
110pub struct MonomorphizeInstanceArtifactCacheKey {
111    pub compiled_artifact_format_version: u32,
112    pub parser_format_version: u32,
113    pub finalize_format_version: u32,
114    pub unit_count: usize,
115    pub manifest_hash: u64,
116}
117
118impl MonomorphizeInstanceArtifactCacheKey {
119    pub fn from_manifests(
120        manifests: &[CompiledUnitManifest],
121    ) -> Result<Self, MonomorphizeInstanceArtifactCacheError> {
122        let Some(first) = manifests.first() else {
123            return Err(MonomorphizeInstanceArtifactCacheError::EmptyManifestSet);
124        };
125        for manifest in manifests {
126            if manifest.artifact_format_version != first.artifact_format_version
127                || manifest.parser_format_version != first.parser_format_version
128            {
129                return Err(MonomorphizeInstanceArtifactCacheError::MixedCompiledFormats);
130            }
131        }
132        Ok(Self {
133            compiled_artifact_format_version: first.artifact_format_version,
134            parser_format_version: first.parser_format_version,
135            finalize_format_version: MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION,
136            unit_count: manifests.len(),
137            manifest_hash: hash_compiled_unit_manifests(manifests),
138        })
139    }
140
141    fn directory(&self, root: &Path) -> PathBuf {
142        root.join(format!("v{}", self.compiled_artifact_format_version))
143            .join(format!("parser-v{}", self.parser_format_version))
144            .join(format!(
145                "runtime-finalize-v{}",
146                self.finalize_format_version
147            ))
148    }
149
150    fn file_name(&self) -> String {
151        format!(
152            "instances-{}-{:016x}.bin",
153            self.unit_count, self.manifest_hash
154        )
155    }
156}
157
/// Failures reported while deriving cache keys or reading and writing cache artifacts.
///
/// I/O and codec failures are carried as rendered strings rather than as the
/// underlying error values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MonomorphizeInstanceArtifactCacheError {
    /// No manifests were supplied, so no cache key can be derived.
    EmptyManifestSet,
    /// The manifests disagree on artifact or parser format version.
    MixedCompiledFormats,
    /// A decoded surface carries a format version this build does not accept.
    UnsupportedFinalizeFormat { format_version: u32 },
    /// A file-system operation on `path` failed.
    Io { path: PathBuf, error: String },
    /// Encoding a surface failed.
    Serialize { error: String },
    /// The bytes read from `path` did not decode into a surface.
    Deserialize { path: PathBuf, error: String },
}

impl std::fmt::Display for MonomorphizeInstanceArtifactCacheError {
    /// Renders a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyManifestSet => {
                f.write_str("cannot derive an instance cache key from an empty manifest set")
            }
            Self::MixedCompiledFormats => {
                f.write_str("compiled unit manifests mix artifact or parser format versions")
            }
            Self::UnsupportedFinalizeFormat { format_version } => {
                write!(f, "unsupported instance cache format version {format_version}")
            }
            Self::Io { path, error } => {
                write!(f, "instance cache I/O error at {}: {error}", path.display())
            }
            Self::Serialize { error } => {
                write!(f, "failed to serialize instance cache: {error}")
            }
            Self::Deserialize { path, error } => write!(
                f,
                "failed to deserialize instance cache at {}: {error}",
                path.display()
            ),
        }
    }
}

// Lets callers box the error as `dyn Error` and propagate it with `?`.
impl std::error::Error for MonomorphizeInstanceArtifactCacheError {}
167
168#[derive(Debug, Clone, Default, PartialEq, Eq)]
169pub struct MonomorphizeInstanceCache {
170    entries: HashMap<MonomorphizeInstanceKey, CachedMonomorphizeInstance>,
171    policy: MonomorphizeInstanceCachePolicy,
172    profile: MonomorphizeInstanceCacheProfile,
173}
174
175impl MonomorphizeInstanceCache {
176    pub fn new(policy: MonomorphizeInstanceCachePolicy) -> Self {
177        Self {
178            policy,
179            ..Self::default()
180        }
181    }
182
183    pub fn from_surface(surface: MonomorphizeInstanceCacheSurface) -> Self {
184        if surface.format_version != MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION {
185            return Self::default();
186        }
187        let entries = surface
188            .instances
189            .into_iter()
190            .map(|instance| (instance.key.clone(), instance))
191            .collect();
192        Self {
193            entries,
194            policy: MonomorphizeInstanceCachePolicy::default(),
195            profile: MonomorphizeInstanceCacheProfile::default(),
196        }
197    }
198
199    pub fn to_surface(&self) -> MonomorphizeInstanceCacheSurface {
200        MonomorphizeInstanceCacheSurface {
201            format_version: MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION,
202            instances: self.entries.values().cloned().collect(),
203        }
204    }
205
206    pub fn profile(&self) -> MonomorphizeInstanceCacheProfile {
207        self.profile
208    }
209
210    pub fn get(&mut self, key: &MonomorphizeInstanceKey) -> Option<&CachedMonomorphizeInstance> {
211        let instance = self.entries.get(key);
212        if instance.is_some() {
213            self.profile.hits += 1;
214        } else {
215            self.profile.misses += 1;
216        }
217        instance
218    }
219
220    pub fn insert(&mut self, instance: CachedMonomorphizeInstance) {
221        if self.entries.contains_key(&instance.key) {
222            return;
223        }
224        if self
225            .policy
226            .max_entries
227            .is_some_and(|max_entries| self.entries.len() >= max_entries)
228        {
229            self.profile.skipped_full += 1;
230            return;
231        }
232        let nodes = expr_node_count(&instance.body);
233        if self
234            .policy
235            .max_body_nodes
236            .is_some_and(|max_nodes| nodes > max_nodes)
237        {
238            self.profile.skipped_large_body += 1;
239            return;
240        }
241        self.profile.inserts += 1;
242        self.entries.insert(instance.key.clone(), instance);
243    }
244}
245
/// Size limits applied when instances are inserted into the in-memory cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MonomorphizeInstanceCachePolicy {
    /// Maximum number of cached entries; `None` disables the limit.
    pub max_entries: Option<usize>,
    /// Maximum node count of a cached body; `None` disables the limit.
    pub max_body_nodes: Option<usize>,
}

impl Default for MonomorphizeInstanceCachePolicy {
    /// Limits the cache to 4096 entries with bodies of at most 2048 nodes.
    fn default() -> Self {
        let max_entries = Some(4096);
        let max_body_nodes = Some(2048);
        Self {
            max_entries,
            max_body_nodes,
        }
    }
}
260
/// Usage counters maintained by the in-memory instance cache.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct MonomorphizeInstanceCacheProfile {
    /// Lookups that found an entry.
    pub hits: usize,
    /// Lookups that found nothing.
    pub misses: usize,
    /// Instances actually stored.
    pub inserts: usize,
    /// Insertions dropped because the entry limit was reached.
    pub skipped_full: usize,
    /// Insertions dropped because the body exceeded the node limit.
    pub skipped_large_body: usize,
}
269
270#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
271pub struct MonomorphizeInstanceCacheSurface {
272    pub format_version: u32,
273    pub instances: Vec<CachedMonomorphizeInstance>,
274}
275
276impl Default for MonomorphizeInstanceCacheSurface {
277    fn default() -> Self {
278        Self {
279            format_version: MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION,
280            instances: Vec::new(),
281        }
282    }
283}
284
/// Lookup key of a cached instance: a binding path, the type substitutions applied to
/// it, and the callee's runtime type.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MonomorphizeInstanceKey {
    /// Path of the binding the instance belongs to.
    pub binding: typed_ir::Path,
    /// Type substitutions that distinguish this instance.
    pub substitutions: Vec<typed_ir::TypeSubstitution>,
    /// Runtime type of the callee for this instance.
    pub callee_type: RuntimeType,
}
291
292#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
293pub struct CachedMonomorphizeInstance {
294    pub key: MonomorphizeInstanceKey,
295    pub scheme: typed_ir::Scheme,
296    pub body: Expr,
297    pub callee_type: RuntimeType,
298    pub result_type: RuntimeType,
299}
300
301impl CachedMonomorphizeInstance {
302    pub fn binding_with_alias(&self, alias: typed_ir::Path) -> Binding {
303        Binding {
304            name: alias,
305            type_params: Vec::new(),
306            scheme: self.scheme.clone(),
307            body: self.body.clone(),
308        }
309    }
310}
311
312fn expr_node_count(expr: &Expr) -> usize {
313    1 + match &expr.kind {
314        ExprKind::Lambda { body, .. }
315        | ExprKind::BindHere { expr: body }
316        | ExprKind::Thunk { expr: body, .. }
317        | ExprKind::LocalPushId { body, .. }
318        | ExprKind::AddId { thunk: body, .. }
319        | ExprKind::Coerce { expr: body, .. }
320        | ExprKind::Pack { expr: body, .. }
321        | ExprKind::Select { base: body, .. } => expr_node_count(body),
322        ExprKind::Apply { callee, arg, .. } => expr_node_count(callee) + expr_node_count(arg),
323        ExprKind::If {
324            cond,
325            then_branch,
326            else_branch,
327            ..
328        } => expr_node_count(cond) + expr_node_count(then_branch) + expr_node_count(else_branch),
329        ExprKind::Tuple(items) => items.iter().map(expr_node_count).sum(),
330        ExprKind::Record { fields, spread } => {
331            fields
332                .iter()
333                .map(|field| expr_node_count(&field.value))
334                .sum::<usize>()
335                + spread.as_ref().map_or(0, record_spread_expr_node_count)
336        }
337        ExprKind::Variant { value, .. } => value.as_deref().map_or(0, expr_node_count),
338        ExprKind::Match {
339            scrutinee, arms, ..
340        } => {
341            expr_node_count(scrutinee)
342                + arms
343                    .iter()
344                    .map(|arm| {
345                        arm.guard.as_ref().map_or(0, expr_node_count) + expr_node_count(&arm.body)
346                    })
347                    .sum::<usize>()
348        }
349        ExprKind::Block { stmts, tail } => {
350            stmts.iter().map(stmt_node_count).sum::<usize>()
351                + tail.as_deref().map_or(0, expr_node_count)
352        }
353        ExprKind::Handle { body, arms, .. } => {
354            expr_node_count(body)
355                + arms
356                    .iter()
357                    .map(|arm| {
358                        arm.guard.as_ref().map_or(0, expr_node_count) + expr_node_count(&arm.body)
359                    })
360                    .sum::<usize>()
361        }
362        ExprKind::Var(_)
363        | ExprKind::EffectOp(_)
364        | ExprKind::PrimitiveOp(_)
365        | ExprKind::Lit(_)
366        | ExprKind::PeekId
367        | ExprKind::FindId { .. } => 0,
368    }
369}
370
371fn record_spread_expr_node_count(spread: &yulang_runtime_ir::FinalizedRecordSpreadExpr) -> usize {
372    match spread {
373        yulang_runtime_ir::FinalizedRecordSpreadExpr::Head(expr)
374        | yulang_runtime_ir::FinalizedRecordSpreadExpr::Tail(expr) => expr_node_count(expr),
375    }
376}
377
378fn stmt_node_count(stmt: &yulang_runtime_ir::FinalizedStmt) -> usize {
379    match stmt {
380        yulang_runtime_ir::FinalizedStmt::Let { value, .. } => expr_node_count(value),
381        yulang_runtime_ir::FinalizedStmt::Expr(expr)
382        | yulang_runtime_ir::FinalizedStmt::Module { body: expr, .. } => expr_node_count(expr),
383    }
384}
385
386fn hash_compiled_unit_manifests(manifests: &[CompiledUnitManifest]) -> u64 {
387    let mut hasher = DefaultHasher::new();
388    for manifest in manifests {
389        manifest.artifact_format_version.hash(&mut hasher);
390        manifest.parser_format_version.hash(&mut hasher);
391        manifest.unit_index.hash(&mut hasher);
392        source_compilation_unit_origin_key(manifest.origin).hash(&mut hasher);
393        for realm in &manifest.realms {
394            realm.identity.hash(&mut hasher);
395            realm.version.hash(&mut hasher);
396        }
397        for band in &manifest.bands {
398            band.realm.identity.hash(&mut hasher);
399            band.realm.version.hash(&mut hasher);
400            band.band.segments.hash(&mut hasher);
401        }
402        manifest.source_hash.hash(&mut hasher);
403        manifest.syntax_hash.hash(&mut hasher);
404        manifest.interface_hash.hash(&mut hasher);
405        for file in &manifest.files {
406            file.path.hash(&mut hasher);
407            file.module_path.segments.hash(&mut hasher);
408            file.origin.hash(&mut hasher);
409            file.source_len.hash(&mut hasher);
410            file.source_hash.hash(&mut hasher);
411        }
412        for dependency in &manifest.dependencies {
413            dependency.unit_index.hash(&mut hasher);
414            dependency.source_hash.hash(&mut hasher);
415            dependency.interface_hash.hash(&mut hasher);
416        }
417    }
418    hasher.finish()
419}
420
421fn source_compilation_unit_origin_key(origin: yulang_sources::SourceCompilationUnitOrigin) -> u8 {
422    match origin {
423        yulang_sources::SourceCompilationUnitOrigin::Entry => 0,
424        yulang_sources::SourceCompilationUnitOrigin::Std => 1,
425        yulang_sources::SourceCompilationUnitOrigin::User => 2,
426        yulang_sources::SourceCompilationUnitOrigin::Mixed => 3,
427    }
428}
429
/// Renders an I/O error as text. Errors of kind `NotFound` collapse to the fixed string
/// `"not found"`; every other error uses its `Display` output.
fn io_error_string(error: io::Error) -> String {
    if error.kind() == io::ErrorKind::NotFound {
        return String::from("not found");
    }
    error.to_string()
}
436
#[cfg(test)]
mod tests {
    use super::*;
    use yulang_sources::{
        CompiledSourceFileIdentity, CompiledUnitDependency, SourceCompilationUnitOrigin,
        SourceOrigin,
    };

    /// Writes a surface, checks the versioned directory in the artifact path, and reads
    /// the surface back both directly and as an in-memory cache.
    #[test]
    fn artifact_cache_uses_compiled_unit_manifest_key() {
        let scratch_name = format!("yulang-finalize-cache-test-{}", std::process::id());
        let scratch_root = std::env::temp_dir().join(scratch_name);
        // Best-effort removal of leftovers from an earlier run.
        let _ = fs::remove_dir_all(&scratch_root);

        let artifact_cache = MonomorphizeInstanceArtifactCache::new(&scratch_root);
        let units = vec![manifest(0, 11), manifest(1, 29)];
        let expected = MonomorphizeInstanceCacheSurface {
            format_version: MONOMORPHIZE_INSTANCE_CACHE_FORMAT_VERSION,
            instances: vec![cached_instance()],
        };

        let written = artifact_cache
            .write_for_manifests(&units, &expected)
            .expect("write finalize instance cache");
        let versioned_dir = "runtime-finalize-v2";
        assert!(
            written
                .components()
                .any(|part| part.as_os_str() == versioned_dir)
        );

        let surface_from_disk = artifact_cache
            .read_for_manifests(&units)
            .expect("read finalize instance cache");
        let cache_from_disk = artifact_cache.read_cache_for_manifests(&units);
        // Clean up before asserting so a failing assertion leaves nothing behind.
        let _ = fs::remove_dir_all(&scratch_root);

        assert_eq!(surface_from_disk, expected);
        assert_eq!(cache_from_disk.to_surface(), expected);
    }

    /// A manifest set whose artifact format versions differ must not yield a key.
    #[test]
    fn artifact_cache_rejects_mixed_compiled_formats() {
        let mut bumped = manifest(1, 29);
        bumped.artifact_format_version += 1;
        let units = [manifest(0, 11), bumped];

        let outcome = MonomorphizeInstanceArtifactCacheKey::from_manifests(&units);

        assert_eq!(
            outcome,
            Err(MonomorphizeInstanceArtifactCacheError::MixedCompiledFormats)
        );
    }

    /// Fixture: an instance for binding `id` with `a := int` whose body is the literal `1`.
    fn cached_instance() -> CachedMonomorphizeInstance {
        let int = typed_ir::Type::Named {
            path: typed_ir::Path::from_name(typed_ir::Name("int".into())),
            args: Vec::new(),
        };
        let int_value = RuntimeType::Value(int.clone());

        let key = MonomorphizeInstanceKey {
            binding: typed_ir::Path::from_name(typed_ir::Name("id".into())),
            substitutions: vec![typed_ir::TypeSubstitution {
                var: typed_ir::TypeVar("a".into()),
                ty: int.clone(),
            }],
            callee_type: int_value.clone(),
        };
        let scheme = typed_ir::Scheme {
            requirements: Vec::new(),
            body: int,
        };
        let body = Expr::typed(
            ExprKind::Lit(typed_ir::Lit::Int("1".into())),
            int_value.clone(),
        );

        CachedMonomorphizeInstance {
            key,
            scheme,
            body,
            callee_type: int_value.clone(),
            result_type: int_value,
        }
    }

    /// Fixture: a std-origin manifest with one file; every unit after the first depends
    /// on its predecessor.
    fn manifest(unit_index: usize, hash: u64) -> CompiledUnitManifest {
        let file = CompiledSourceFileIdentity {
            path: format!("std/{unit_index}.yu"),
            module_path: typed_ir::Path::from_name(typed_ir::Name(format!("m{unit_index}"))),
            origin: SourceOrigin::Std,
            source_len: 10,
            source_hash: hash,
        };
        // `checked_sub` is `Some` exactly when `unit_index > 0`.
        let dependencies = unit_index
            .checked_sub(1)
            .map(|previous| CompiledUnitDependency {
                unit_index: previous,
                source_hash: hash - 1,
                interface_hash: hash + 1,
            })
            .into_iter()
            .collect();

        CompiledUnitManifest {
            artifact_format_version: 17,
            parser_format_version: 1,
            unit_index,
            origin: SourceCompilationUnitOrigin::Std,
            realms: Vec::new(),
            bands: Vec::new(),
            files: vec![file],
            dependencies,
            source_hash: hash,
            syntax_hash: hash + 10,
            interface_hash: hash + 20,
        }
    }
}