1use std::collections::HashMap;
8
9use rayon::prelude::*;
10
11use crate::bytecode::{
12 BytecodeProgram, Constant, DebugInfo, Function, FunctionBlob, FunctionHash, Instruction,
13 LinkedFunction, LinkedProgram, Operand, Program, SourceMap,
14};
15use shape_abi_v1::PermissionSet;
16use shape_value::{FunctionId, StringId};
17
/// Errors reported while linking a content-addressed program.
#[derive(Debug, thiserror::Error)]
pub enum LinkError {
    /// A function hash was looked up in the program's function store but no
    /// blob is stored under it. Carries the hash that could not be found.
    #[error("Missing function blob: {0}")]
    MissingBlob(FunctionHash),
    /// The dependency walk reached a function that is still being visited,
    /// i.e. the dependency graph contains a cycle.
    #[error("Circular dependency detected")]
    CircularDependency,
    /// The merged constant pool holds more entries than a `u16` index can
    /// address. Carries the total number of constants.
    #[error("Constant pool overflow: {0} constants exceeds u16 max")]
    ConstantPoolOverflow(usize),
    /// The merged string pool holds more entries than a `u32` index can
    /// address. Carries the total number of strings.
    #[error("String pool overflow: {0} strings exceeds u32 max")]
    StringPoolOverflow(usize),
}
33
34fn topo_sort(program: &Program) -> Result<Vec<FunctionHash>, LinkError> {
42 let mut state: HashMap<FunctionHash, u8> = HashMap::new();
44 let mut order: Vec<FunctionHash> = Vec::with_capacity(program.function_store.len());
45
46 fn visit(
47 hash: FunctionHash,
48 program: &Program,
49 state: &mut HashMap<FunctionHash, u8>,
50 order: &mut Vec<FunctionHash>,
51 ) -> Result<(), LinkError> {
52 match state.get(&hash).copied().unwrap_or(0) {
53 2 => return Ok(()), 1 => return Err(LinkError::CircularDependency),
55 _ => {}
56 }
57 state.insert(hash, 1); let blob = program
60 .function_store
61 .get(&hash)
62 .ok_or(LinkError::MissingBlob(hash))?;
63
64 for dep in &blob.dependencies {
65 if *dep == FunctionHash::ZERO {
68 continue;
69 }
70 visit(*dep, program, state, order)?;
71 }
72
73 state.insert(hash, 2); order.push(hash);
75 Ok(())
76 }
77
78 visit(program.entry, program, &mut state, &mut order)?;
82
83 let remaining: Vec<FunctionHash> = program
86 .function_store
87 .keys()
88 .copied()
89 .filter(|h| state.get(h).copied().unwrap_or(0) != 2)
90 .collect();
91 for hash in remaining {
92 visit(hash, program, &mut state, &mut order)?;
93 }
94
95 Ok(order)
96}
97
98fn remap_operand(
105 operand: Operand,
106 const_base: usize,
107 string_base: usize,
108 blob: &FunctionBlob,
109 current_function_id: usize,
110 hash_to_id: &HashMap<FunctionHash, usize>,
111 name_to_id: &HashMap<&str, usize>,
112) -> Operand {
113 match operand {
114 Operand::Const(i) => Operand::Const((const_base + i as usize) as u16),
115 Operand::Property(i) => Operand::Property((string_base + i as usize) as u16),
116 Operand::Name(StringId(i)) => Operand::Name(StringId((string_base + i as usize) as u32)),
117 Operand::Function(FunctionId(dep_idx)) => {
118 if let Some(dep_hash) = blob.dependencies.get(dep_idx as usize) {
119 if *dep_hash == FunctionHash::ZERO {
120 if let Some(callee_name) = blob.callee_names.get(dep_idx as usize) {
123 if callee_name != &blob.name {
124 if let Some(target_id) = name_to_id.get(callee_name.as_str()) {
126 Operand::Function(FunctionId(*target_id as u16))
127 } else {
128 Operand::Function(FunctionId(current_function_id as u16))
130 }
131 } else {
132 Operand::Function(FunctionId(current_function_id as u16))
134 }
135 } else {
136 Operand::Function(FunctionId(current_function_id as u16))
138 }
139 } else {
140 let linked_id = hash_to_id[dep_hash];
141 Operand::Function(FunctionId(linked_id as u16))
142 }
143 } else {
144 Operand::Function(FunctionId(dep_idx))
146 }
147 }
148 Operand::MethodCall { name, arg_count } => Operand::MethodCall {
149 name: StringId((string_base + name.0 as usize) as u32),
150 arg_count,
151 },
152 Operand::TypedMethodCall {
153 method_id,
154 arg_count,
155 string_id,
156 } => Operand::TypedMethodCall {
157 method_id,
158 arg_count,
159 string_id: (string_base + string_id as usize) as u16,
160 },
161 Operand::Offset(_)
163 | Operand::Local(_)
164 | Operand::ModuleBinding(_)
165 | Operand::Builtin(_)
166 | Operand::Count(_)
167 | Operand::ColumnIndex(_)
168 | Operand::TypedField { .. }
169 | Operand::TypedObjectAlloc { .. }
170 | Operand::TypedMerge { .. }
171 | Operand::ColumnAccess { .. }
172 | Operand::ForeignFunction(_)
173 | Operand::MatrixDims { .. }
174 | Operand::Width(_)
175 | Operand::TypedLocal(_, _)
176 | Operand::TypedModuleBinding(_, _) => operand,
177 }
178}
179
180fn remap_constant(
188 constant: &Constant,
189 blob: &FunctionBlob,
190 current_function_id: usize,
191 hash_to_id: &HashMap<FunctionHash, usize>,
192 name_to_id: &HashMap<&str, usize>,
193) -> Constant {
194 match constant {
195 Constant::Function(dep_idx) => {
196 let dep_idx = *dep_idx as usize;
197 if dep_idx < blob.dependencies.len() {
198 let dep_hash = blob.dependencies[dep_idx];
199 if dep_hash == FunctionHash::ZERO {
200 if let Some(callee_name) = blob.callee_names.get(dep_idx) {
202 if callee_name != &blob.name {
203 if let Some(target_id) = name_to_id.get(callee_name.as_str()) {
205 Constant::Function(*target_id as u16)
206 } else {
207 Constant::Function(current_function_id as u16)
208 }
209 } else {
210 Constant::Function(current_function_id as u16)
211 }
212 } else {
213 Constant::Function(current_function_id as u16)
214 }
215 } else {
216 let linked_id = hash_to_id[&dep_hash];
217 Constant::Function(linked_id as u16)
218 }
219 } else {
220 constant.clone()
222 }
223 }
224 other => other.clone(),
225 }
226}
227
/// Blob count above which `link` remaps blobs with rayon's parallel iterators
/// instead of a sequential loop (the comparison is strict: `len > threshold`).
const PARALLEL_THRESHOLD: usize = 50;
235
/// Starting positions of one blob's data inside the merged program arrays,
/// computed by `link` as running totals over the blobs that precede it.
struct BlobOffsets {
    /// Index of the blob's first instruction in the merged instruction array.
    instruction_base: usize,
    /// Index of the blob's first constant in the merged constant pool.
    const_base: usize,
    /// Index of the blob's first string in the merged string pool.
    string_base: usize,
}
242
243pub fn link(program: &Program) -> Result<LinkedProgram, LinkError> {
254 let sorted = topo_sort(program)?;
255
256 let blobs: Vec<&FunctionBlob> = sorted
258 .iter()
259 .map(|h| {
260 program
261 .function_store
262 .get(h)
263 .ok_or(LinkError::MissingBlob(*h))
264 })
265 .collect::<Result<Vec<_>, _>>()?;
266
267 let mut offsets: Vec<BlobOffsets> = Vec::with_capacity(blobs.len());
271 let mut hash_to_id: HashMap<FunctionHash, usize> = HashMap::with_capacity(blobs.len());
272 let mut name_to_id: HashMap<&str, usize> = HashMap::with_capacity(blobs.len());
273
274 let mut total_instructions: usize = 0;
275 let mut total_constants: usize = 0;
276 let mut total_strings: usize = 0;
277
278 for (i, blob) in blobs.iter().enumerate() {
279 offsets.push(BlobOffsets {
280 instruction_base: total_instructions,
281 const_base: total_constants,
282 string_base: total_strings,
283 });
284 hash_to_id.insert(blob.content_hash, i);
285 name_to_id.insert(&blob.name, i);
286
287 total_instructions += blob.instructions.len();
288 total_constants += blob.constants.len();
289 total_strings += blob.strings.len();
290 }
291
292 if total_constants > u16::MAX as usize + 1 {
294 return Err(LinkError::ConstantPoolOverflow(total_constants));
295 }
296 if total_strings > u32::MAX as usize + 1 {
297 return Err(LinkError::StringPoolOverflow(total_strings));
298 }
299
300 let total_required_permissions = blobs.iter().fold(PermissionSet::pure(), |acc, blob| {
302 acc.union(&blob.required_permissions)
303 });
304
305 let use_parallel = blobs.len() > PARALLEL_THRESHOLD;
309
310 let mut instructions: Vec<Instruction> = Vec::with_capacity(total_instructions);
312 let mut constants: Vec<Constant> = Vec::with_capacity(total_constants);
313 let mut strings: Vec<String> = Vec::with_capacity(total_strings);
314
315 if use_parallel {
316 struct BlobResult {
330 instructions: Vec<Instruction>,
331 constants: Vec<Constant>,
332 strings: Vec<String>,
333 source_map: Vec<(usize, u16, u32)>,
334 }
335
336 let results: Vec<BlobResult> = blobs
337 .par_iter()
338 .zip(offsets.par_iter())
339 .enumerate()
340 .map(|(function_id, (blob, off))| {
341 let remapped_instrs: Vec<Instruction> = blob
342 .instructions
343 .iter()
344 .map(|instr| {
345 let remapped_operand = instr.operand.map(|op| {
346 remap_operand(
347 op,
348 off.const_base,
349 off.string_base,
350 blob,
351 function_id,
352 &hash_to_id,
353 &name_to_id,
354 )
355 });
356 Instruction {
357 opcode: instr.opcode,
358 operand: remapped_operand,
359 }
360 })
361 .collect();
362
363 let remapped_consts: Vec<Constant> = blob
364 .constants
365 .iter()
366 .map(|c| remap_constant(c, blob, function_id, &hash_to_id, &name_to_id))
367 .collect();
368
369 let cloned_strings: Vec<String> = blob.strings.clone();
370
371 let source_entries: Vec<(usize, u16, u32)> = blob
372 .source_map
373 .iter()
374 .map(|&(local_offset, file_id, line)| {
375 (off.instruction_base + local_offset, file_id as u16, line)
376 })
377 .collect();
378
379 BlobResult {
380 instructions: remapped_instrs,
381 constants: remapped_consts,
382 strings: cloned_strings,
383 source_map: source_entries,
384 }
385 })
386 .collect();
387
388 let mut merged_line_numbers: Vec<(usize, u16, u32)> = Vec::new();
391 for result in results {
392 instructions.extend(result.instructions);
393 constants.extend(result.constants);
394 strings.extend(result.strings);
395 merged_line_numbers.extend(result.source_map);
396 }
397
398 merged_line_numbers.sort_by_key(|&(offset, _, _)| offset);
399
400 let functions: Vec<LinkedFunction> = blobs
401 .iter()
402 .zip(offsets.iter())
403 .map(|(blob, off)| LinkedFunction {
404 blob_hash: blob.content_hash,
405 entry_point: off.instruction_base,
406 body_length: blob.instructions.len(),
407 name: blob.name.clone(),
408 arity: blob.arity,
409 param_names: blob.param_names.clone(),
410 locals_count: blob.locals_count,
411 is_closure: blob.is_closure,
412 captures_count: blob.captures_count,
413 is_async: blob.is_async,
414 ref_params: blob.ref_params.clone(),
415 ref_mutates: blob.ref_mutates.clone(),
416 mutable_captures: blob.mutable_captures.clone(),
417 frame_descriptor: blob.frame_descriptor.clone(),
418 })
419 .collect();
420
421 let debug_info = DebugInfo {
422 source_map: SourceMap {
423 files: program.debug_info.source_map.files.clone(),
424 source_texts: program.debug_info.source_map.source_texts.clone(),
425 },
426 line_numbers: merged_line_numbers,
427 variable_names: program.debug_info.variable_names.clone(),
428 source_text: String::new(),
429 };
430
431 return Ok(LinkedProgram {
432 entry: program.entry,
433 instructions,
434 constants,
435 strings,
436 functions,
437 hash_to_id,
438 debug_info,
439 data_schema: program.data_schema.clone(),
440 module_binding_names: program.module_binding_names.clone(),
441 top_level_locals_count: program.top_level_locals_count,
442 top_level_local_storage_hints: program.top_level_local_storage_hints.clone(),
443 type_schema_registry: program.type_schema_registry.clone(),
444 module_binding_storage_hints: program.module_binding_storage_hints.clone(),
445 function_local_storage_hints: program.function_local_storage_hints.clone(),
446 top_level_frame: program.top_level_frame.clone(),
447 trait_method_symbols: program.trait_method_symbols.clone(),
448 foreign_functions: program.foreign_functions.clone(),
449 native_struct_layouts: program.native_struct_layouts.clone(),
450 total_required_permissions: total_required_permissions.clone(),
451 });
452 }
453
454 let mut merged_line_numbers: Vec<(usize, u16, u32)> = Vec::new();
458
459 for (function_id, (blob, off)) in blobs.iter().zip(offsets.iter()).enumerate() {
460 for instr in &blob.instructions {
462 let remapped_operand = instr.operand.map(|op| {
463 remap_operand(
464 op,
465 off.const_base,
466 off.string_base,
467 blob,
468 function_id,
469 &hash_to_id,
470 &name_to_id,
471 )
472 });
473 instructions.push(Instruction {
474 opcode: instr.opcode,
475 operand: remapped_operand,
476 });
477 }
478
479 for c in &blob.constants {
481 constants.push(remap_constant(
482 c,
483 blob,
484 function_id,
485 &hash_to_id,
486 &name_to_id,
487 ));
488 }
489
490 strings.extend(blob.strings.iter().cloned());
492
493 for &(local_offset, file_id, line) in &blob.source_map {
495 let global_offset = off.instruction_base + local_offset;
496 merged_line_numbers.push((global_offset, file_id as u16, line));
497 }
498 }
499
500 merged_line_numbers.sort_by_key(|&(offset, _, _)| offset);
502
503 let functions: Vec<LinkedFunction> = blobs
504 .iter()
505 .zip(offsets.iter())
506 .map(|(blob, off)| LinkedFunction {
507 blob_hash: blob.content_hash,
508 entry_point: off.instruction_base,
509 body_length: blob.instructions.len(),
510 name: blob.name.clone(),
511 arity: blob.arity,
512 param_names: blob.param_names.clone(),
513 locals_count: blob.locals_count,
514 is_closure: blob.is_closure,
515 captures_count: blob.captures_count,
516 is_async: blob.is_async,
517 ref_params: blob.ref_params.clone(),
518 ref_mutates: blob.ref_mutates.clone(),
519 mutable_captures: blob.mutable_captures.clone(),
520 frame_descriptor: blob.frame_descriptor.clone(),
521 })
522 .collect();
523
524 let debug_info = DebugInfo {
525 source_map: SourceMap {
526 files: program.debug_info.source_map.files.clone(),
527 source_texts: program.debug_info.source_map.source_texts.clone(),
528 },
529 line_numbers: merged_line_numbers,
530 variable_names: program.debug_info.variable_names.clone(),
531 source_text: String::new(),
532 };
533
534 Ok(LinkedProgram {
535 entry: program.entry,
536 instructions,
537 constants,
538 strings,
539 functions,
540 hash_to_id,
541 debug_info,
542 data_schema: program.data_schema.clone(),
543 module_binding_names: program.module_binding_names.clone(),
544 top_level_locals_count: program.top_level_locals_count,
545 top_level_local_storage_hints: program.top_level_local_storage_hints.clone(),
546 type_schema_registry: program.type_schema_registry.clone(),
547 module_binding_storage_hints: program.module_binding_storage_hints.clone(),
548 function_local_storage_hints: program.function_local_storage_hints.clone(),
549 top_level_frame: program.top_level_frame.clone(),
550 trait_method_symbols: program.trait_method_symbols.clone(),
551 foreign_functions: program.foreign_functions.clone(),
552 native_struct_layouts: program.native_struct_layouts.clone(),
553 total_required_permissions,
554 })
555}
556
557pub fn linked_to_bytecode_program(linked: &LinkedProgram) -> BytecodeProgram {
564 let functions: Vec<Function> = linked
565 .functions
566 .iter()
567 .map(|lf| Function {
568 name: lf.name.clone(),
569 arity: lf.arity,
570 param_names: lf.param_names.clone(),
571 locals_count: lf.locals_count,
572 entry_point: lf.entry_point,
573 body_length: lf.body_length,
574 is_closure: lf.is_closure,
575 captures_count: lf.captures_count,
576 is_async: lf.is_async,
577 ref_params: lf.ref_params.clone(),
578 ref_mutates: lf.ref_mutates.clone(),
579 mutable_captures: lf.mutable_captures.clone(),
580 frame_descriptor: lf.frame_descriptor.clone(),
581 osr_entry_points: Vec::new(),
582 })
583 .collect();
584
585 BytecodeProgram {
586 instructions: linked.instructions.clone(),
587 constants: linked.constants.clone(),
588 strings: linked.strings.clone(),
589 functions,
590 debug_info: linked.debug_info.clone(),
591 data_schema: linked.data_schema.clone(),
592 module_binding_names: linked.module_binding_names.clone(),
593 top_level_locals_count: linked.top_level_locals_count,
594 top_level_local_storage_hints: linked.top_level_local_storage_hints.clone(),
595 type_schema_registry: linked.type_schema_registry.clone(),
596 module_binding_storage_hints: linked.module_binding_storage_hints.clone(),
597 function_local_storage_hints: linked.function_local_storage_hints.clone(),
598 top_level_frame: linked.top_level_frame.clone(),
599 compiled_annotations: HashMap::new(),
600 trait_method_symbols: linked.trait_method_symbols.clone(),
601 expanded_function_defs: HashMap::new(),
602 string_index: HashMap::new(),
603 foreign_functions: linked.foreign_functions.clone(),
604 native_struct_layouts: linked.native_struct_layouts.clone(),
605 content_addressed: None,
606 function_blob_hashes: linked
607 .functions
608 .iter()
609 .map(|lf| {
610 if lf.blob_hash == FunctionHash::ZERO {
611 None
612 } else {
613 Some(lf.blob_hash)
614 }
615 })
616 .collect(),
617 }
618}
619
// Unit tests are compiled only for test builds and live in `linker_tests.rs`.
#[cfg(test)]
#[path = "linker_tests.rs"]
mod tests;