Skip to main content

vyre_driver/
aot.rs

1//! Backend-neutral AOT emission and launcher registries.
2
3use std::collections::BTreeMap;
4use std::path::PathBuf;
5
6use crate::{BackendError, DispatchConfig};
7use vyre_foundation::ir::Program;
8
/// Stable, `'static` name of an AOT target.
///
/// Emitters are looked up by comparing this identifier for string equality.
pub type AotTargetId = &'static str;
11
12/// One backend-owned AOT emitter.
13pub struct AotEmitter {
14    /// Stable target identifier.
15    pub target: AotTargetId,
16    /// Emit target-native bytes for `program`.
17    pub emit: fn(&Program, &DispatchConfig) -> Result<Vec<u8>, String>,
18}
19
20inventory::collect!(AotEmitter);
21
/// A dependency that a generated launcher crate must declare in its manifest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LauncherDependency {
    /// Crate name as written in the emitted `Cargo.toml`.
    pub name: &'static str,
    /// Inline value of the entry, e.g. `"0.2"` or
    /// `{ version = "1", features = ["derive"] }`.
    pub spec: &'static str,
}
30
31/// Backend-neutral launcher emission request.
32#[derive(Debug)]
33pub struct AotLauncherRequest<'a> {
34    /// Stable target id matching [`AotEmitter::target`].
35    pub target: AotTargetId,
36    /// Generated launcher crate name.
37    pub crate_name: &'a str,
38    /// Whether to include target-owned collective/multi-rank support.
39    pub include_collectives: bool,
40    /// Whether to include a built-in eval-time training loop.
41    pub include_ttt_loop: bool,
42}
43
44/// Source files and manifest additions produced by a target-owned launcher emitter.
45#[derive(Debug, Clone, Default)]
46pub struct AotLauncherFiles {
47    /// Additional dependencies required by target-specific launcher files.
48    pub dependencies: Vec<LauncherDependency>,
49    /// Source files keyed by launcher-crate-relative path.
50    pub files: BTreeMap<PathBuf, String>,
51}
52
53impl AotLauncherFiles {
54    /// Build launcher files from a fixed backend emission list.
55    ///
56    /// Backends should emit files in a deterministic order and delegate the
57    /// final path-keyed container construction here instead of open-coding
58    /// per-backend map assembly.
59    #[must_use]
60    pub fn from_entries(
61        dependencies: Vec<LauncherDependency>,
62        entries: impl IntoIterator<Item = (PathBuf, String)>,
63    ) -> Self {
64        Self {
65            dependencies,
66            files: entries.into_iter().collect(),
67        }
68    }
69}
70
71/// One backend-owned launcher source emitter.
72pub struct AotLauncherEmitter {
73    /// Stable target identifier.
74    pub target: AotTargetId,
75    /// Emit target-owned launcher files for `request`.
76    pub emit: fn(&AotLauncherRequest<'_>) -> Result<AotLauncherFiles, String>,
77}
78
79inventory::collect!(AotLauncherEmitter);
80
81/// Return every linked AOT emitter.
82#[must_use]
83pub fn registered_aot_emitters() -> Vec<&'static AotEmitter> {
84    let emitter_count = inventory::iter::<AotEmitter>.into_iter().count();
85    let mut emitters = Vec::new();
86    let _ = emitters.try_reserve_exact(emitter_count);
87    emitters.extend(inventory::iter::<AotEmitter>);
88    emitters
89}
90
91/// Return every linked launcher emitter.
92#[must_use]
93pub fn registered_aot_launcher_emitters() -> Vec<&'static AotLauncherEmitter> {
94    let emitter_count = inventory::iter::<AotLauncherEmitter>.into_iter().count();
95    let mut emitters = Vec::new();
96    let _ = emitters.try_reserve_exact(emitter_count);
97    emitters.extend(inventory::iter::<AotLauncherEmitter>);
98    emitters
99}
100
101/// Emit target-native bytes through the linked emitter matching `target`.
102///
103/// # Errors
104///
105/// Returns [`BackendError::UnsupportedFeature`] when no linked backend owns
106/// `target`, or [`BackendError::KernelCompileFailed`] when the concrete
107/// emitter rejects the program.
108pub fn emit_aot_target(
109    target: &str,
110    program: &Program,
111    config: &DispatchConfig,
112) -> Result<Vec<u8>, BackendError> {
113    let Some(emitter) = inventory::iter::<AotEmitter>
114        .into_iter()
115        .find(|emitter| emitter.target == target)
116    else {
117        return Err(BackendError::UnsupportedFeature {
118            name: format!("aot target `{target}`"),
119            backend: "vyre-driver".to_string(),
120        });
121    };
122    (emitter.emit)(program, config).map_err(|compiler_message| BackendError::KernelCompileFailed {
123        backend: target.to_string(),
124        compiler_message,
125    })
126}
127
128/// Emit target-owned launcher files through the linked emitter matching `target`.
129///
130/// # Errors
131///
132/// Returns [`BackendError::UnsupportedFeature`] when no linked backend owns
133/// launcher generation for `target`, or [`BackendError::KernelCompileFailed`]
134/// when the concrete launcher emitter rejects the request.
135pub fn emit_aot_launcher_target(
136    target: &str,
137    request: &AotLauncherRequest<'_>,
138) -> Result<AotLauncherFiles, BackendError> {
139    let Some(emitter) = inventory::iter::<AotLauncherEmitter>
140        .into_iter()
141        .find(|emitter| emitter.target == target)
142    else {
143        return Err(BackendError::UnsupportedFeature {
144            name: format!("aot launcher target `{target}`"),
145            backend: "vyre-driver".to_string(),
146        });
147    };
148    (emitter.emit)(request).map_err(|compiler_message| BackendError::KernelCompileFailed {
149        backend: target.to_string(),
150        compiler_message,
151    })
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// `from_entries` must keep every dependency and every file's contents,
    /// keyed by its launcher-relative path.
    #[test]
    fn launcher_files_constructor_centralizes_path_keyed_container_assembly() {
        let main_path = PathBuf::from("src/main.rs");
        let deps = vec![LauncherDependency {
            name: "libc",
            spec: "\"0.2\"",
        }];
        let entries = [
            (main_path.clone(), String::from("fn main() {}")),
            (PathBuf::from("src/cuda_ffi.rs"), String::from("mod ffi {}")),
        ];

        let built = AotLauncherFiles::from_entries(deps, entries);

        assert_eq!(built.dependencies.len(), 1);
        assert_eq!(built.files.len(), 2);
        assert_eq!(
            built.files[&main_path],
            "fn main() {}",
            "Fix: launcher file construction must preserve emitted file contents while centralizing the map-shaped public API."
        );
    }
}