Skip to main content

shape_vm/
linker.rs

1//! Linking pass: converts a content-addressed `Program` into a flat `LinkedProgram`.
2//!
3//! The linker topologically sorts function blobs by their dependency edges,
4//! then flattens per-blob instruction/constant/string pools into merged arrays,
5//! remapping operand indices so they reference the correct global positions.
6
7use std::collections::HashMap;
8
9use rayon::prelude::*;
10
11use crate::bytecode::{
12    BytecodeProgram, Constant, DebugInfo, Function, FunctionBlob, FunctionHash, Instruction,
13    LinkedFunction, LinkedProgram, Operand, Program, SourceMap,
14};
15use shape_abi_v1::PermissionSet;
16use shape_value::{FunctionId, StringId};
17
18// ---------------------------------------------------------------------------
19// Error type
20// ---------------------------------------------------------------------------
21
22#[derive(Debug, thiserror::Error)]
23pub enum LinkError {
24    #[error("Missing function blob: {0}")]
25    MissingBlob(FunctionHash),
26    #[error("Circular dependency detected")]
27    CircularDependency,
28    #[error("Constant pool overflow: {0} constants exceeds u16 max")]
29    ConstantPoolOverflow(usize),
30    #[error("String pool overflow: {0} strings exceeds u32 max")]
31    StringPoolOverflow(usize),
32}
33
34// ---------------------------------------------------------------------------
35// Topological sort
36// ---------------------------------------------------------------------------
37
38/// Topologically sort blobs so that every dependency appears before
39/// the blob that depends on it.  Returns blob hashes in dependency order
40/// (leaves first, entry last).
41fn topo_sort(program: &Program) -> Result<Vec<FunctionHash>, LinkError> {
42    // States: 0 = unvisited, 1 = in-progress, 2 = done
43    let mut state: HashMap<FunctionHash, u8> = HashMap::new();
44    let mut order: Vec<FunctionHash> = Vec::with_capacity(program.function_store.len());
45
46    fn visit(
47        hash: FunctionHash,
48        program: &Program,
49        state: &mut HashMap<FunctionHash, u8>,
50        order: &mut Vec<FunctionHash>,
51    ) -> Result<(), LinkError> {
52        match state.get(&hash).copied().unwrap_or(0) {
53            2 => return Ok(()), // already done
54            1 => return Err(LinkError::CircularDependency),
55            _ => {}
56        }
57        state.insert(hash, 1); // mark in-progress
58
59        let blob = program
60            .function_store
61            .get(&hash)
62            .ok_or(LinkError::MissingBlob(hash))?;
63
64        for dep in &blob.dependencies {
65            // ZERO is an explicit self-recursion sentinel produced by the compiler.
66            // It does not reference a separate blob in function_store.
67            if *dep == FunctionHash::ZERO {
68                continue;
69            }
70            visit(*dep, program, state, order)?;
71        }
72
73        state.insert(hash, 2); // done
74        order.push(hash);
75        Ok(())
76    }
77
78    // Visit all blobs reachable from the entry point.  We start from entry
79    // so unreachable blobs are excluded (they could be present in the store
80    // from incremental compilation).
81    visit(program.entry, program, &mut state, &mut order)?;
82
83    // Also visit any remaining blobs not reachable from entry
84    // (e.g. blobs referenced only from constants or other metadata).
85    let remaining: Vec<FunctionHash> = program
86        .function_store
87        .keys()
88        .copied()
89        .filter(|h| state.get(h).copied().unwrap_or(0) != 2)
90        .collect();
91    for hash in remaining {
92        visit(hash, program, &mut state, &mut order)?;
93    }
94
95    Ok(order)
96}
97
98// ---------------------------------------------------------------------------
99// Operand remapping
100// ---------------------------------------------------------------------------
101
102/// Remap a single operand given the per-blob base offsets and function
103/// hash-to-id mapping.
104fn remap_operand(
105    operand: Operand,
106    const_base: usize,
107    string_base: usize,
108    blob: &FunctionBlob,
109    current_function_id: usize,
110    hash_to_id: &HashMap<FunctionHash, usize>,
111    name_to_id: &HashMap<&str, usize>,
112) -> Operand {
113    match operand {
114        Operand::Const(i) => Operand::Const((const_base + i as usize) as u16),
115        Operand::Property(i) => Operand::Property((string_base + i as usize) as u16),
116        Operand::Name(StringId(i)) => Operand::Name(StringId((string_base + i as usize) as u32)),
117        Operand::Function(FunctionId(dep_idx)) => {
118            if let Some(dep_hash) = blob.dependencies.get(dep_idx as usize) {
119                if *dep_hash == FunctionHash::ZERO {
120                    // ZERO sentinel: self-recursion or mutual recursion.
121                    // Check callee_names to determine which function to target.
122                    if let Some(callee_name) = blob.callee_names.get(dep_idx as usize) {
123                        if callee_name != &blob.name {
124                            // Mutual recursion: look up the target by name.
125                            if let Some(target_id) = name_to_id.get(callee_name.as_str()) {
126                                Operand::Function(FunctionId(*target_id as u16))
127                            } else {
128                                // Fallback: self (shouldn't happen for valid programs).
129                                Operand::Function(FunctionId(current_function_id as u16))
130                            }
131                        } else {
132                            // Self-recursion.
133                            Operand::Function(FunctionId(current_function_id as u16))
134                        }
135                    } else {
136                        // No callee name info; assume self-recursion.
137                        Operand::Function(FunctionId(current_function_id as u16))
138                    }
139                } else {
140                    let linked_id = hash_to_id[dep_hash];
141                    Operand::Function(FunctionId(linked_id as u16))
142                }
143            } else {
144                // Defensive fallback for blobs emitted with already-global function ids.
145                Operand::Function(FunctionId(dep_idx))
146            }
147        }
148        Operand::MethodCall { name, arg_count } => Operand::MethodCall {
149            name: StringId((string_base + name.0 as usize) as u32),
150            arg_count,
151        },
152        Operand::TypedMethodCall {
153            method_id,
154            arg_count,
155            string_id,
156        } => Operand::TypedMethodCall {
157            method_id,
158            arg_count,
159            string_id: (string_base + string_id as usize) as u16,
160        },
161        // Unchanged operands:
162        Operand::Offset(_)
163        | Operand::Local(_)
164        | Operand::ModuleBinding(_)
165        | Operand::Builtin(_)
166        | Operand::Count(_)
167        | Operand::ColumnIndex(_)
168        | Operand::TypedField { .. }
169        | Operand::TypedObjectAlloc { .. }
170        | Operand::TypedMerge { .. }
171        | Operand::ColumnAccess { .. }
172        | Operand::ForeignFunction(_)
173        | Operand::MatrixDims { .. }
174        | Operand::Width(_)
175        | Operand::TypedLocal(_, _)
176        | Operand::TypedModuleBinding(_, _) => operand,
177    }
178}
179
180// ---------------------------------------------------------------------------
181// Constant remapping
182// ---------------------------------------------------------------------------
183
184/// Remap function references inside `Constant::Function(idx)`.
185/// These reference dependency indices within the blob, not global function IDs,
186/// so they need the same treatment as `Operand::Function`.
187fn remap_constant(
188    constant: &Constant,
189    blob: &FunctionBlob,
190    current_function_id: usize,
191    hash_to_id: &HashMap<FunctionHash, usize>,
192    name_to_id: &HashMap<&str, usize>,
193) -> Constant {
194    match constant {
195        Constant::Function(dep_idx) => {
196            let dep_idx = *dep_idx as usize;
197            if dep_idx < blob.dependencies.len() {
198                let dep_hash = blob.dependencies[dep_idx];
199                if dep_hash == FunctionHash::ZERO {
200                    // ZERO sentinel: self-recursion or mutual recursion.
201                    if let Some(callee_name) = blob.callee_names.get(dep_idx) {
202                        if callee_name != &blob.name {
203                            // Mutual recursion: look up the target by name.
204                            if let Some(target_id) = name_to_id.get(callee_name.as_str()) {
205                                Constant::Function(*target_id as u16)
206                            } else {
207                                Constant::Function(current_function_id as u16)
208                            }
209                        } else {
210                            Constant::Function(current_function_id as u16)
211                        }
212                    } else {
213                        Constant::Function(current_function_id as u16)
214                    }
215                } else {
216                    let linked_id = hash_to_id[&dep_hash];
217                    Constant::Function(linked_id as u16)
218                }
219            } else {
220                // dep_idx doesn't map to a dependency — keep as-is.
221                constant.clone()
222            }
223        }
224        other => other.clone(),
225    }
226}
227
228// ---------------------------------------------------------------------------
229// Public API: link
230// ---------------------------------------------------------------------------
231
/// Blob count above which `link` remaps blobs in parallel.
/// At or below this count the sequential path is used, since the overhead of
/// Rayon's thread pool is not worth it for small programs.
const PARALLEL_THRESHOLD: usize = 50;
235
/// Starting positions of one blob's data inside the merged output arrays,
/// computed in Pass 1 as running totals over the preceding blobs.
struct BlobOffsets {
    /// Index of the blob's first instruction in the merged instruction array.
    instruction_base: usize,
    /// Index of the blob's first constant in the merged constant pool.
    const_base: usize,
    /// Index of the blob's first string in the merged string pool.
    string_base: usize,
}
242
243/// Link a content-addressed `Program` into a flat `LinkedProgram`.
244///
245/// The linker:
246/// 1. Topologically sorts function blobs by dependencies.
247/// 2. **Pass 1 (sequential):** Computes cumulative base offsets for each blob
248///    and builds the `hash_to_id` reverse index.
249/// 3. **Pass 2 (parallel for >50 functions):** Each blob independently remaps
250///    its instructions/constants/strings into pre-allocated output arrays at
251///    non-overlapping offsets.
252/// 4. Builds a `LinkedFunction` table and merged debug info.
253pub fn link(program: &Program) -> Result<LinkedProgram, LinkError> {
254    let sorted = topo_sort(program)?;
255
256    // Resolve sorted hashes to blob references up-front.
257    let blobs: Vec<&FunctionBlob> = sorted
258        .iter()
259        .map(|h| {
260            program
261                .function_store
262                .get(h)
263                .ok_or(LinkError::MissingBlob(*h))
264        })
265        .collect::<Result<Vec<_>, _>>()?;
266
267    // ------------------------------------------------------------------
268    // Pass 1 (sequential): compute base offsets and hash_to_id
269    // ------------------------------------------------------------------
270    let mut offsets: Vec<BlobOffsets> = Vec::with_capacity(blobs.len());
271    let mut hash_to_id: HashMap<FunctionHash, usize> = HashMap::with_capacity(blobs.len());
272    let mut name_to_id: HashMap<&str, usize> = HashMap::with_capacity(blobs.len());
273
274    let mut total_instructions: usize = 0;
275    let mut total_constants: usize = 0;
276    let mut total_strings: usize = 0;
277
278    for (i, blob) in blobs.iter().enumerate() {
279        offsets.push(BlobOffsets {
280            instruction_base: total_instructions,
281            const_base: total_constants,
282            string_base: total_strings,
283        });
284        hash_to_id.insert(blob.content_hash, i);
285        name_to_id.insert(&blob.name, i);
286
287        total_instructions += blob.instructions.len();
288        total_constants += blob.constants.len();
289        total_strings += blob.strings.len();
290    }
291
292    // Overflow checks on totals.
293    if total_constants > u16::MAX as usize + 1 {
294        return Err(LinkError::ConstantPoolOverflow(total_constants));
295    }
296    if total_strings > u32::MAX as usize + 1 {
297        return Err(LinkError::StringPoolOverflow(total_strings));
298    }
299
300    // Compute transitive union of all required permissions across all blobs.
301    let total_required_permissions = blobs.iter().fold(PermissionSet::pure(), |acc, blob| {
302        acc.union(&blob.required_permissions)
303    });
304
305    // ------------------------------------------------------------------
306    // Pass 2: remap and write into pre-allocated arrays
307    // ------------------------------------------------------------------
308    let use_parallel = blobs.len() > PARALLEL_THRESHOLD;
309
310    // Pre-allocate output arrays with exact sizes.
311    let mut instructions: Vec<Instruction> = Vec::with_capacity(total_instructions);
312    let mut constants: Vec<Constant> = Vec::with_capacity(total_constants);
313    let mut strings: Vec<String> = Vec::with_capacity(total_strings);
314
315    if use_parallel {
316        // SAFETY: We write to non-overlapping regions of the output arrays.
317        // Each blob writes to [base..base+len) which is disjoint from all
318        // other blobs because the bases are cumulative sums of prior sizes.
319        // We use `set_len` after all writes to make the Vecs aware of the data.
320
321        // Extend vecs to their full capacity with uninitialized-safe defaults.
322        // For Instructions (Copy type), use zeroed memory via MaybeUninit logic.
323        // For Constant/String (non-Copy), we must use a different strategy:
324        // collect per-blob results in parallel, then write sequentially.
325
326        // Strategy: parallel map each blob to its (remapped_instructions,
327        // remapped_constants, cloned_strings, source_map_entries), then
328        // write them into the pre-allocated arrays sequentially (memcpy-fast).
329        struct BlobResult {
330            instructions: Vec<Instruction>,
331            constants: Vec<Constant>,
332            strings: Vec<String>,
333            source_map: Vec<(usize, u16, u32)>,
334        }
335
336        let results: Vec<BlobResult> = blobs
337            .par_iter()
338            .zip(offsets.par_iter())
339            .enumerate()
340            .map(|(function_id, (blob, off))| {
341                let remapped_instrs: Vec<Instruction> = blob
342                    .instructions
343                    .iter()
344                    .map(|instr| {
345                        let remapped_operand = instr.operand.map(|op| {
346                            remap_operand(
347                                op,
348                                off.const_base,
349                                off.string_base,
350                                blob,
351                                function_id,
352                                &hash_to_id,
353                                &name_to_id,
354                            )
355                        });
356                        Instruction {
357                            opcode: instr.opcode,
358                            operand: remapped_operand,
359                        }
360                    })
361                    .collect();
362
363                let remapped_consts: Vec<Constant> = blob
364                    .constants
365                    .iter()
366                    .map(|c| remap_constant(c, blob, function_id, &hash_to_id, &name_to_id))
367                    .collect();
368
369                let cloned_strings: Vec<String> = blob.strings.clone();
370
371                let source_entries: Vec<(usize, u16, u32)> = blob
372                    .source_map
373                    .iter()
374                    .map(|&(local_offset, file_id, line)| {
375                        (off.instruction_base + local_offset, file_id as u16, line)
376                    })
377                    .collect();
378
379                BlobResult {
380                    instructions: remapped_instrs,
381                    constants: remapped_consts,
382                    strings: cloned_strings,
383                    source_map: source_entries,
384                }
385            })
386            .collect();
387
388        // Now write results into the pre-allocated arrays (sequential, but
389        // this is just memcpy/move of contiguous data -- very fast).
390        let mut merged_line_numbers: Vec<(usize, u16, u32)> = Vec::new();
391        for result in results {
392            instructions.extend(result.instructions);
393            constants.extend(result.constants);
394            strings.extend(result.strings);
395            merged_line_numbers.extend(result.source_map);
396        }
397
398        merged_line_numbers.sort_by_key(|&(offset, _, _)| offset);
399
400        let functions: Vec<LinkedFunction> = blobs
401            .iter()
402            .zip(offsets.iter())
403            .map(|(blob, off)| LinkedFunction {
404                blob_hash: blob.content_hash,
405                entry_point: off.instruction_base,
406                body_length: blob.instructions.len(),
407                name: blob.name.clone(),
408                arity: blob.arity,
409                param_names: blob.param_names.clone(),
410                locals_count: blob.locals_count,
411                is_closure: blob.is_closure,
412                captures_count: blob.captures_count,
413                is_async: blob.is_async,
414                ref_params: blob.ref_params.clone(),
415                ref_mutates: blob.ref_mutates.clone(),
416                mutable_captures: blob.mutable_captures.clone(),
417                frame_descriptor: blob.frame_descriptor.clone(),
418            })
419            .collect();
420
421        let debug_info = DebugInfo {
422            source_map: SourceMap {
423                files: program.debug_info.source_map.files.clone(),
424                source_texts: program.debug_info.source_map.source_texts.clone(),
425            },
426            line_numbers: merged_line_numbers,
427            variable_names: program.debug_info.variable_names.clone(),
428            source_text: String::new(),
429        };
430
431        return Ok(LinkedProgram {
432            entry: program.entry,
433            instructions,
434            constants,
435            strings,
436            functions,
437            hash_to_id,
438            debug_info,
439            data_schema: program.data_schema.clone(),
440            module_binding_names: program.module_binding_names.clone(),
441            top_level_locals_count: program.top_level_locals_count,
442            top_level_local_storage_hints: program.top_level_local_storage_hints.clone(),
443            type_schema_registry: program.type_schema_registry.clone(),
444            module_binding_storage_hints: program.module_binding_storage_hints.clone(),
445            function_local_storage_hints: program.function_local_storage_hints.clone(),
446            top_level_frame: program.top_level_frame.clone(),
447            trait_method_symbols: program.trait_method_symbols.clone(),
448            foreign_functions: program.foreign_functions.clone(),
449            native_struct_layouts: program.native_struct_layouts.clone(),
450            total_required_permissions: total_required_permissions.clone(),
451        });
452    }
453
454    // ------------------------------------------------------------------
455    // Sequential path (≤ PARALLEL_THRESHOLD functions)
456    // ------------------------------------------------------------------
457    let mut merged_line_numbers: Vec<(usize, u16, u32)> = Vec::new();
458
459    for (function_id, (blob, off)) in blobs.iter().zip(offsets.iter()).enumerate() {
460        // Remap and copy instructions.
461        for instr in &blob.instructions {
462            let remapped_operand = instr.operand.map(|op| {
463                remap_operand(
464                    op,
465                    off.const_base,
466                    off.string_base,
467                    blob,
468                    function_id,
469                    &hash_to_id,
470                    &name_to_id,
471                )
472            });
473            instructions.push(Instruction {
474                opcode: instr.opcode,
475                operand: remapped_operand,
476            });
477        }
478
479        // Merge constants (remap Constant::Function).
480        for c in &blob.constants {
481            constants.push(remap_constant(
482                c,
483                blob,
484                function_id,
485                &hash_to_id,
486                &name_to_id,
487            ));
488        }
489
490        // Merge strings.
491        strings.extend(blob.strings.iter().cloned());
492
493        // Merge source map entries.
494        for &(local_offset, file_id, line) in &blob.source_map {
495            let global_offset = off.instruction_base + local_offset;
496            merged_line_numbers.push((global_offset, file_id as u16, line));
497        }
498    }
499
500    // Sort line numbers by instruction offset for correct binary-search lookup.
501    merged_line_numbers.sort_by_key(|&(offset, _, _)| offset);
502
503    let functions: Vec<LinkedFunction> = blobs
504        .iter()
505        .zip(offsets.iter())
506        .map(|(blob, off)| LinkedFunction {
507            blob_hash: blob.content_hash,
508            entry_point: off.instruction_base,
509            body_length: blob.instructions.len(),
510            name: blob.name.clone(),
511            arity: blob.arity,
512            param_names: blob.param_names.clone(),
513            locals_count: blob.locals_count,
514            is_closure: blob.is_closure,
515            captures_count: blob.captures_count,
516            is_async: blob.is_async,
517            ref_params: blob.ref_params.clone(),
518            ref_mutates: blob.ref_mutates.clone(),
519            mutable_captures: blob.mutable_captures.clone(),
520            frame_descriptor: blob.frame_descriptor.clone(),
521        })
522        .collect();
523
524    let debug_info = DebugInfo {
525        source_map: SourceMap {
526            files: program.debug_info.source_map.files.clone(),
527            source_texts: program.debug_info.source_map.source_texts.clone(),
528        },
529        line_numbers: merged_line_numbers,
530        variable_names: program.debug_info.variable_names.clone(),
531        source_text: String::new(),
532    };
533
534    Ok(LinkedProgram {
535        entry: program.entry,
536        instructions,
537        constants,
538        strings,
539        functions,
540        hash_to_id,
541        debug_info,
542        data_schema: program.data_schema.clone(),
543        module_binding_names: program.module_binding_names.clone(),
544        top_level_locals_count: program.top_level_locals_count,
545        top_level_local_storage_hints: program.top_level_local_storage_hints.clone(),
546        type_schema_registry: program.type_schema_registry.clone(),
547        module_binding_storage_hints: program.module_binding_storage_hints.clone(),
548        function_local_storage_hints: program.function_local_storage_hints.clone(),
549        top_level_frame: program.top_level_frame.clone(),
550        trait_method_symbols: program.trait_method_symbols.clone(),
551        foreign_functions: program.foreign_functions.clone(),
552        native_struct_layouts: program.native_struct_layouts.clone(),
553        total_required_permissions,
554    })
555}
556
557// ---------------------------------------------------------------------------
558// Public API: linked_to_bytecode_program
559// ---------------------------------------------------------------------------
560
561/// Convert a `LinkedProgram` back to the legacy `BytecodeProgram` format
562/// for backward compatibility with the existing VM executor.
563pub fn linked_to_bytecode_program(linked: &LinkedProgram) -> BytecodeProgram {
564    let functions: Vec<Function> = linked
565        .functions
566        .iter()
567        .map(|lf| Function {
568            name: lf.name.clone(),
569            arity: lf.arity,
570            param_names: lf.param_names.clone(),
571            locals_count: lf.locals_count,
572            entry_point: lf.entry_point,
573            body_length: lf.body_length,
574            is_closure: lf.is_closure,
575            captures_count: lf.captures_count,
576            is_async: lf.is_async,
577            ref_params: lf.ref_params.clone(),
578            ref_mutates: lf.ref_mutates.clone(),
579            mutable_captures: lf.mutable_captures.clone(),
580            frame_descriptor: lf.frame_descriptor.clone(),
581            osr_entry_points: Vec::new(),
582        })
583        .collect();
584
585    BytecodeProgram {
586        instructions: linked.instructions.clone(),
587        constants: linked.constants.clone(),
588        strings: linked.strings.clone(),
589        functions,
590        debug_info: linked.debug_info.clone(),
591        data_schema: linked.data_schema.clone(),
592        module_binding_names: linked.module_binding_names.clone(),
593        top_level_locals_count: linked.top_level_locals_count,
594        top_level_local_storage_hints: linked.top_level_local_storage_hints.clone(),
595        type_schema_registry: linked.type_schema_registry.clone(),
596        module_binding_storage_hints: linked.module_binding_storage_hints.clone(),
597        function_local_storage_hints: linked.function_local_storage_hints.clone(),
598        top_level_frame: linked.top_level_frame.clone(),
599        compiled_annotations: HashMap::new(),
600        trait_method_symbols: linked.trait_method_symbols.clone(),
601        expanded_function_defs: HashMap::new(),
602        string_index: HashMap::new(),
603        foreign_functions: linked.foreign_functions.clone(),
604        native_struct_layouts: linked.native_struct_layouts.clone(),
605        content_addressed: None,
606        function_blob_hashes: linked
607            .functions
608            .iter()
609            .map(|lf| {
610                if lf.blob_hash == FunctionHash::ZERO {
611                    None
612                } else {
613                    Some(lf.blob_hash)
614                }
615            })
616            .collect(),
617    }
618}
619
620// ---------------------------------------------------------------------------
621// Tests
622// ---------------------------------------------------------------------------
623
624#[cfg(test)]
625#[path = "linker_tests.rs"]
626mod tests;