Skip to main content

rusty_javac/classfile/
writer.rs

1use rust_asm::class_reader::{AttributeInfo, LocalVariable, read_class_file};
2use rust_asm::class_writer::{
3    COMPUTE_FRAMES, COMPUTE_MAXS, ClassWriter as AsmClassWriter, FieldVisitor, MethodVisitor,
4};
5use rust_asm::constant_pool::{ConstantPoolBuilder, CpInfo};
6pub use rust_asm::insn::Label;
7use rust_asm::insn::LabelNode;
8pub use rust_asm::insn::{BootstrapArgument, Handle};
9use std::collections::HashMap;
10
/// Class-file builder layered over `rust_asm`'s `ClassWriter`.
///
/// Headers, fields, methods and instructions are forwarded to the wrapped writer
/// right away. Attributes that this type attaches itself are buffered here until
/// [`ClassFileWriter::to_bytes`]: generic signatures, `throws` clauses,
/// `NestHost`/`NestMembers`, `Record`, `RuntimeInvisibleAnnotations` and local
/// variable tables.
pub struct ClassFileWriter {
    /// Wrapped `rust_asm` writer (`new` creates it with `COMPUTE_FRAMES | COMPUTE_MAXS`).
    cw: AsmClassWriter,
    /// Class-level generic `Signature`, if one was set.
    class_signature: Option<String>,
    /// Nest host class name, emitted as a `NestHost` attribute.
    nest_host: Option<String>,
    /// Nest member class names, emitted as a `NestMembers` attribute.
    nest_members: Vec<String>,
    /// Components emitted in the `Record` attribute, in visit order.
    record_components: Vec<RecordComponentMetadata>,
    /// Class-level annotations emitted as `RuntimeInvisibleAnnotations`.
    runtime_invisible_annotations: Vec<AnnotationMetadata>,
    /// Per-method metadata pushed by `MethodWriter::visit_end`, in call order.
    method_metadata: Vec<MethodMetadata>,
    /// Per-field metadata pushed by `FieldWriter::visit_end`, in call order.
    field_metadata: Vec<FieldMetadata>,
}
21
22impl ClassFileWriter {
23    pub fn new() -> Self {
24        Self {
25            cw: AsmClassWriter::new(COMPUTE_FRAMES | COMPUTE_MAXS),
26            class_signature: None,
27            nest_host: None,
28            nest_members: Vec::new(),
29            record_components: Vec::new(),
30            runtime_invisible_annotations: Vec::new(),
31            method_metadata: Vec::new(),
32            field_metadata: Vec::new(),
33        }
34    }
35
36    pub fn visit(
37        &mut self,
38        major_version: u16,
39        access_flags: u16,
40        name: &str,
41        super_name: Option<&str>,
42        interfaces: &[&str],
43    ) {
44        self.cw
45            .visit(major_version, 0, access_flags, name, super_name, interfaces);
46    }
47
48    pub fn visit_method(
49        &mut self,
50        access_flags: u16,
51        name: &str,
52        descriptor: &str,
53    ) -> MethodWriter {
54        let mv = self.cw.visit_method(access_flags, name, descriptor);
55        MethodWriter {
56            inner: mv,
57            name: name.to_string(),
58            descriptor: descriptor.to_string(),
59            signature: None,
60            exceptions: Vec::new(),
61            local_variables: Vec::new(),
62        }
63    }
64
65    pub fn visit_field(&mut self, access_flags: u16, name: &str, descriptor: &str) -> FieldWriter {
66        let fv = self.cw.visit_field(access_flags, name, descriptor);
67        FieldWriter {
68            inner: fv,
69            name: name.to_string(),
70            descriptor: descriptor.to_string(),
71            signature: None,
72        }
73    }
74
75    pub fn visit_signature(&mut self, signature: &str) {
76        self.class_signature = Some(signature.to_string());
77    }
78
79    pub fn visit_source_file(&mut self, name: &str) {
80        self.cw.visit_source_file(name);
81    }
82
83    pub fn visit_nest_host(&mut self, host: &str) {
84        self.nest_host = Some(host.to_string());
85    }
86
87    pub fn visit_nest_member(&mut self, member: &str) {
88        self.nest_members.push(member.to_string());
89    }
90
91    pub fn visit_record_component(
92        &mut self,
93        name: &str,
94        descriptor: &str,
95        signature: Option<&str>,
96    ) {
97        self.record_components.push(RecordComponentMetadata {
98            name: name.to_string(),
99            descriptor: descriptor.to_string(),
100            signature: signature.map(str::to_string),
101        });
102    }
103
104    pub fn visit_runtime_invisible_annotation(&mut self, annotation: AnnotationMetadata) {
105        self.runtime_invisible_annotations.push(annotation);
106    }
107
108    pub fn to_bytes(self) -> Result<Vec<u8>, String> {
109        let mut class_node = self.cw.to_class_node().map_err(|e| e.to_string())?;
110        add_extra_attributes(
111            &mut class_node,
112            ExtraAttributes {
113                class_signature: self.class_signature.as_deref(),
114                nest_host: self.nest_host.as_deref(),
115                nest_members: &self.nest_members,
116                record_components: &self.record_components,
117                runtime_invisible_annotations: &self.runtime_invisible_annotations,
118                field_metadata: &self.field_metadata,
119                method_metadata: &self.method_metadata,
120            },
121        );
122
123        let first_pass =
124            AsmClassWriter::write_class_node(&class_node, COMPUTE_FRAMES | COMPUTE_MAXS)
125                .map_err(|e| format!("{:?}", e))?;
126        let code_lengths = method_code_lengths(&first_pass)?;
127        add_local_variables(&mut class_node, &self.method_metadata, &code_lengths);
128
129        AsmClassWriter::write_class_node(&class_node, COMPUTE_FRAMES | COMPUTE_MAXS)
130            .map_err(|e| format!("{:?}", e))
131    }
132}
133
134impl Default for ClassFileWriter {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
/// Writer for a single method, returned by [`ClassFileWriter::visit_method`].
///
/// Instructions are forwarded straight to the wrapped `rust_asm` visitor. The
/// generic signature, `throws` clause and local variables are buffered here and
/// handed to the owning [`ClassFileWriter`] by [`MethodWriter::visit_end`].
pub struct MethodWriter {
    /// Wrapped `rust_asm` method visitor.
    inner: MethodVisitor,
    /// Method name; with `descriptor`, used to find this method in the class node.
    name: String,
    /// Method descriptor.
    descriptor: String,
    /// Generic `Signature`, if one was set.
    signature: Option<String>,
    /// Exception class names for the `Exceptions` attribute.
    exceptions: Vec<String>,
    /// Variables for the `LocalVariableTable`, in visit order.
    local_variables: Vec<LocalVariableSpec>,
}
148
149impl MethodWriter {
150    pub fn visit_code(&mut self) {
151        self.inner.visit_code();
152    }
153
154    pub fn visit_insn(&mut self, opcode: u8) {
155        self.inner.visit_insn(opcode);
156    }
157
158    pub fn visit_var_insn(&mut self, opcode: u8, var_index: u16) {
159        self.inner.visit_var_insn(opcode, var_index);
160    }
161
162    pub fn visit_type_insn(&mut self, opcode: u8, type_name: &str) {
163        self.inner.visit_type_insn(opcode, type_name);
164    }
165
166    pub fn visit_new_array(&mut self, array_type: u8) {
167        self.inner
168            .visit_var_insn(rust_asm::opcodes::NEWARRAY, array_type as u16);
169    }
170
171    pub fn visit_jump_insn(&mut self, opcode: u8, target: Label) {
172        self.inner.visit_jump_insn(opcode, target);
173    }
174
175    pub fn visit_lookup_switch(&mut self, default: Label, pairs: &[(i32, Label)]) {
176        self.inner.visit_lookup_switch(default, pairs);
177    }
178
179    pub fn visit_label(&mut self, label: Label) {
180        self.inner.visit_label(label);
181    }
182
183    pub fn visit_line_number(&mut self, line: u16, label: Label) {
184        self.inner
185            .visit_line_number(line, LabelNode::from_label(label));
186    }
187
188    pub fn visit_try_catch_block(
189        &mut self,
190        start: Label,
191        end: Label,
192        handler: Label,
193        catch_type: Option<&str>,
194    ) {
195        self.inner
196            .visit_try_catch_block(start, end, handler, catch_type);
197    }
198
199    pub fn visit_local_variable(&mut self, name: &str, descriptor: &str, index: u16) {
200        self.local_variables.push(LocalVariableSpec {
201            name: name.to_string(),
202            descriptor: descriptor.to_string(),
203            index,
204        });
205    }
206
207    pub fn visit_signature(&mut self, signature: &str) {
208        self.signature = Some(signature.to_string());
209    }
210
211    pub fn visit_exception(&mut self, internal_name: &str) {
212        self.exceptions.push(internal_name.to_string());
213    }
214
215    pub fn visit_field_insn(&mut self, opcode: u8, owner: &str, name: &str, descriptor: &str) {
216        self.inner.visit_field_insn(opcode, owner, name, descriptor);
217    }
218
219    pub fn visit_method_insn(
220        &mut self,
221        opcode: u8,
222        owner: &str,
223        name: &str,
224        descriptor: &str,
225        is_interface: bool,
226    ) {
227        self.inner
228            .visit_method_insn(opcode, owner, name, descriptor, is_interface);
229    }
230
231    pub fn visit_invoke_dynamic_insn(
232        &mut self,
233        name: &str,
234        descriptor: &str,
235        bootstrap_method: Handle,
236        bootstrap_args: &[BootstrapArgument],
237    ) {
238        self.inner
239            .visit_invokedynamic_insn(name, descriptor, bootstrap_method, bootstrap_args);
240    }
241
242    pub fn visit_ldc_insn_int(&mut self, value: i32) {
243        self.inner
244            .visit_ldc_insn(rust_asm::insn::LdcInsnNode::int(value));
245    }
246
247    pub fn visit_ldc_insn_float(&mut self, value: f32) {
248        self.inner
249            .visit_ldc_insn(rust_asm::insn::LdcInsnNode::float(value));
250    }
251
252    pub fn visit_ldc_insn_long(&mut self, value: i64) {
253        self.inner
254            .visit_ldc_insn(rust_asm::insn::LdcInsnNode::long(value));
255    }
256
257    pub fn visit_ldc_insn_double(&mut self, value: f64) {
258        self.inner
259            .visit_ldc_insn(rust_asm::insn::LdcInsnNode::double(value));
260    }
261
262    pub fn visit_ldc_insn_string(&mut self, value: &str) {
263        self.inner
264            .visit_ldc_insn(rust_asm::insn::LdcInsnNode::string(value));
265    }
266
267    pub fn visit_ldc_insn_type(&mut self, type_name: &str) {
268        self.inner
269            .visit_ldc_insn(rust_asm::insn::LdcInsnNode::typed(
270                rust_asm::types::Type::get_object_type(type_name),
271            ));
272    }
273
274    pub fn visit_iinc_insn(&mut self, var_index: u16, increment: i16) {
275        self.inner.visit_iinc_insn(var_index, increment);
276    }
277
278    pub fn visit_maxs(&mut self, max_stack: u16, max_locals: u16) {
279        self.inner.visit_maxs(max_stack, max_locals);
280    }
281
282    pub fn visit_end(self, cw: &mut ClassFileWriter) {
283        cw.method_metadata.push(MethodMetadata {
284            name: self.name.clone(),
285            descriptor: self.descriptor.clone(),
286            signature: self.signature.clone(),
287            exceptions: self.exceptions.clone(),
288            local_variables: self.local_variables.clone(),
289        });
290        self.inner.visit_end(&mut cw.cw);
291    }
292}
293
/// Writer for a single field, returned by [`ClassFileWriter::visit_field`].
///
/// The generic signature is buffered here and handed to the owning
/// [`ClassFileWriter`] by [`FieldWriter::visit_end`].
pub struct FieldWriter {
    /// Wrapped `rust_asm` field visitor.
    inner: FieldVisitor,
    /// Field name; with `descriptor`, used to find this field in the class node.
    name: String,
    /// Field descriptor.
    descriptor: String,
    /// Generic `Signature`, if one was set.
    signature: Option<String>,
}
300
301impl FieldWriter {
302    pub fn visit_signature(&mut self, signature: &str) {
303        self.signature = Some(signature.to_string());
304    }
305
306    pub fn visit_end(self, cw: &mut ClassFileWriter) {
307        cw.field_metadata.push(FieldMetadata {
308            name: self.name.clone(),
309            descriptor: self.descriptor.clone(),
310            signature: self.signature.clone(),
311        });
312        self.inner.visit_end(&mut cw.cw);
313    }
314}
315
/// A local variable recorded by `MethodWriter::visit_local_variable`.
#[derive(Debug, Clone)]
struct LocalVariableSpec {
    /// Variable name as written to the `LocalVariableTable`.
    name: String,
    /// Type descriptor as written to the `LocalVariableTable`.
    descriptor: String,
    /// Local variable slot.
    index: u16,
}

/// Attributes buffered for one method until `ClassFileWriter::to_bytes`.
#[derive(Debug, Clone)]
struct MethodMetadata {
    /// Method name; with `descriptor`, identifies the method in the class node.
    name: String,
    /// Method descriptor.
    descriptor: String,
    /// Generic `Signature`, if any.
    signature: Option<String>,
    /// Class names for the `Exceptions` attribute.
    exceptions: Vec<String>,
    /// Entries for the `LocalVariableTable`.
    local_variables: Vec<LocalVariableSpec>,
}

/// Attributes buffered for one field until `ClassFileWriter::to_bytes`.
#[derive(Debug, Clone)]
struct FieldMetadata {
    /// Field name; with `descriptor`, identifies the field in the class node.
    name: String,
    /// Field descriptor.
    descriptor: String,
    /// Generic `Signature`, if any.
    signature: Option<String>,
}

/// One component of the `Record` attribute.
#[derive(Debug, Clone)]
struct RecordComponentMetadata {
    /// Component name.
    name: String,
    /// Component descriptor.
    descriptor: String,
    /// Generic signature; when present, the component gets a nested `Signature` attribute.
    signature: Option<String>,
}
345
/// A class-level annotation written into the `RuntimeInvisibleAnnotations` attribute.
#[derive(Debug, Clone)]
pub struct AnnotationMetadata {
    /// Annotation type descriptor, written as the annotation's `type_index` UTF-8 constant.
    pub descriptor: String,
    /// Element-value pairs, written in this order.
    pub elements: Vec<AnnotationElementMetadata>,
}

/// One `name = value` pair of an annotation.
#[derive(Debug, Clone)]
pub struct AnnotationElementMetadata {
    /// Element name, written as a UTF-8 constant.
    pub name: String,
    /// Constant value of the element.
    pub value: AnnotationElementValueMetadata,
}

/// Constant element values supported by this writer.
#[derive(Debug, Clone)]
pub enum AnnotationElementValueMetadata {
    /// Written with tag `s` and a UTF-8 constant.
    String(String),
    /// Written with tag `I` and an `Integer` constant. The value is narrowed with
    /// `as i32`, so anything outside the `i32` range is truncated.
    Int(i64),
    /// Written with tag `Z` and an `Integer` constant holding 0 or 1.
    Boolean(bool),
}
364
/// Borrowed view of the buffered metadata that `add_extra_attributes` turns into
/// class, field and method attributes.
struct ExtraAttributes<'a> {
    /// Class-level generic `Signature`.
    class_signature: Option<&'a str>,
    /// `NestHost` class name.
    nest_host: Option<&'a str>,
    /// `NestMembers` class names.
    nest_members: &'a [String],
    /// `Record` components.
    record_components: &'a [RecordComponentMetadata],
    /// Class-level `RuntimeInvisibleAnnotations`.
    runtime_invisible_annotations: &'a [AnnotationMetadata],
    /// Per-field metadata, in `FieldWriter::visit_end` order.
    field_metadata: &'a [FieldMetadata],
    /// Per-method metadata, in `MethodWriter::visit_end` order.
    method_metadata: &'a [MethodMetadata],
}
374
375fn add_extra_attributes(class_node: &mut rust_asm::nodes::ClassNode, extras: ExtraAttributes<'_>) {
376    let mut cp = ConstantPoolBuilder::from_pool(class_node.constant_pool.clone());
377    if extras.class_signature.is_some()
378        || extras
379            .field_metadata
380            .iter()
381            .any(|metadata| metadata.signature.is_some())
382        || extras
383            .method_metadata
384            .iter()
385            .any(|metadata| metadata.signature.is_some())
386    {
387        cp.utf8("Signature");
388    }
389    if extras
390        .method_metadata
391        .iter()
392        .any(|metadata| !metadata.exceptions.is_empty())
393    {
394        cp.utf8("Exceptions");
395    }
396    if extras.nest_host.is_some() {
397        cp.utf8("NestHost");
398    }
399    if !extras.nest_members.is_empty() {
400        cp.utf8("NestMembers");
401    }
402    if !extras.record_components.is_empty() {
403        cp.utf8("Record");
404    }
405    if !extras.runtime_invisible_annotations.is_empty() {
406        cp.utf8("RuntimeInvisibleAnnotations");
407    }
408    for metadata in extras.field_metadata {
409        cp.utf8(&metadata.name);
410        cp.utf8(&metadata.descriptor);
411    }
412    for component in extras.record_components {
413        cp.utf8(&component.name);
414        cp.utf8(&component.descriptor);
415        if let Some(signature) = &component.signature {
416            cp.utf8("Signature");
417            cp.utf8(signature);
418        }
419    }
420    for annotation in extras.runtime_invisible_annotations {
421        cp.utf8(&annotation.descriptor);
422        for element in &annotation.elements {
423            cp.utf8(&element.name);
424            match &element.value {
425                AnnotationElementValueMetadata::String(value) => {
426                    cp.utf8(value);
427                }
428                AnnotationElementValueMetadata::Int(value) => {
429                    cp.integer(*value as i32);
430                }
431                AnnotationElementValueMetadata::Boolean(value) => {
432                    cp.integer(i32::from(*value));
433                }
434            }
435        }
436    }
437    for metadata in extras.method_metadata {
438        cp.utf8(&metadata.name);
439        cp.utf8(&metadata.descriptor);
440    }
441
442    if let Some(signature) = extras.class_signature {
443        add_signature_attribute(&mut class_node.attributes, &mut cp, signature);
444    }
445    if let Some(host) = extras.nest_host {
446        add_nest_host_attribute(&mut class_node.attributes, &mut cp, host);
447    }
448    if !extras.nest_members.is_empty() {
449        add_nest_members_attribute(&mut class_node.attributes, &mut cp, extras.nest_members);
450    }
451    if !extras.record_components.is_empty() {
452        add_record_attribute(
453            &mut class_node.attributes,
454            &mut cp,
455            extras.record_components,
456        );
457    }
458    if !extras.runtime_invisible_annotations.is_empty() {
459        add_runtime_invisible_annotations_attribute(
460            &mut class_node.attributes,
461            &mut cp,
462            extras.runtime_invisible_annotations,
463        );
464    }
465
466    for (field, metadata) in class_node.fields.iter_mut().zip(extras.field_metadata) {
467        if field.name == metadata.name
468            && field.descriptor == metadata.descriptor
469            && let Some(signature) = metadata.signature.as_deref()
470        {
471            add_signature_attribute(&mut field.attributes, &mut cp, signature);
472        }
473    }
474
475    for (method, metadata) in class_node.methods.iter_mut().zip(extras.method_metadata) {
476        if method.name == metadata.name
477            && method.descriptor == metadata.descriptor
478            && let Some(signature) = metadata.signature.as_deref()
479        {
480            add_signature_attribute(&mut method.attributes, &mut cp, signature);
481        }
482        if method.name == metadata.name
483            && method.descriptor == metadata.descriptor
484            && !metadata.exceptions.is_empty()
485        {
486            add_exceptions_attribute(&mut method.attributes, &mut cp, &metadata.exceptions);
487        }
488    }
489
490    class_node.constant_pool = cp.into_pool();
491}
492
493fn add_runtime_invisible_annotations_attribute(
494    attributes: &mut Vec<AttributeInfo>,
495    cp: &mut ConstantPoolBuilder,
496    annotations: &[AnnotationMetadata],
497) {
498    attributes.retain(
499        |attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "RuntimeInvisibleAnnotations"),
500    );
501    let mut info = Vec::new();
502    info.extend_from_slice(&(annotations.len() as u16).to_be_bytes());
503    for annotation in annotations {
504        info.extend_from_slice(&cp.utf8(&annotation.descriptor).to_be_bytes());
505        info.extend_from_slice(&(annotation.elements.len() as u16).to_be_bytes());
506        for element in &annotation.elements {
507            info.extend_from_slice(&cp.utf8(&element.name).to_be_bytes());
508            write_annotation_value(&mut info, cp, &element.value);
509        }
510    }
511    attributes.push(AttributeInfo::Unknown {
512        name: "RuntimeInvisibleAnnotations".to_string(),
513        info,
514    });
515}
516
517fn write_annotation_value(
518    info: &mut Vec<u8>,
519    cp: &mut ConstantPoolBuilder,
520    value: &AnnotationElementValueMetadata,
521) {
522    match value {
523        AnnotationElementValueMetadata::String(value) => {
524            info.push(b's');
525            info.extend_from_slice(&cp.utf8(value).to_be_bytes());
526        }
527        AnnotationElementValueMetadata::Int(value) => {
528            info.push(b'I');
529            info.extend_from_slice(&cp.integer(*value as i32).to_be_bytes());
530        }
531        AnnotationElementValueMetadata::Boolean(value) => {
532            info.push(b'Z');
533            info.extend_from_slice(&cp.integer(i32::from(*value)).to_be_bytes());
534        }
535    }
536}
537
538fn add_record_attribute(
539    attributes: &mut Vec<AttributeInfo>,
540    cp: &mut ConstantPoolBuilder,
541    components: &[RecordComponentMetadata],
542) {
543    attributes
544        .retain(|attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "Record"));
545    let mut info = Vec::new();
546    info.extend_from_slice(&(components.len() as u16).to_be_bytes());
547    for component in components {
548        info.extend_from_slice(&cp.utf8(&component.name).to_be_bytes());
549        info.extend_from_slice(&cp.utf8(&component.descriptor).to_be_bytes());
550        if let Some(signature) = &component.signature {
551            info.extend_from_slice(&1u16.to_be_bytes());
552            info.extend_from_slice(&cp.utf8("Signature").to_be_bytes());
553            info.extend_from_slice(&2u32.to_be_bytes());
554            info.extend_from_slice(&cp.utf8(signature).to_be_bytes());
555        } else {
556            info.extend_from_slice(&0u16.to_be_bytes());
557        }
558    }
559    attributes.push(AttributeInfo::Unknown {
560        name: "Record".to_string(),
561        info,
562    });
563}
564
565fn add_nest_host_attribute(
566    attributes: &mut Vec<AttributeInfo>,
567    cp: &mut ConstantPoolBuilder,
568    host: &str,
569) {
570    attributes
571        .retain(|attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "NestHost"));
572    attributes.push(AttributeInfo::Unknown {
573        name: "NestHost".to_string(),
574        info: cp.class(host).to_be_bytes().to_vec(),
575    });
576}
577
578fn add_nest_members_attribute(
579    attributes: &mut Vec<AttributeInfo>,
580    cp: &mut ConstantPoolBuilder,
581    members: &[String],
582) {
583    attributes.retain(
584        |attr| !matches!(attr, AttributeInfo::Unknown { name, .. } if name == "NestMembers"),
585    );
586    let mut info = Vec::new();
587    info.extend_from_slice(&(members.len() as u16).to_be_bytes());
588    for member in members {
589        info.extend_from_slice(&cp.class(member).to_be_bytes());
590    }
591    attributes.push(AttributeInfo::Unknown {
592        name: "NestMembers".to_string(),
593        info,
594    });
595}
596
597fn add_signature_attribute(
598    attributes: &mut Vec<AttributeInfo>,
599    cp: &mut ConstantPoolBuilder,
600    signature: &str,
601) {
602    attributes.retain(|attr| !matches!(attr, AttributeInfo::Signature { .. }));
603    let signature_index = cp.utf8(signature);
604    attributes.push(AttributeInfo::Signature { signature_index });
605}
606
607fn add_exceptions_attribute(
608    attributes: &mut Vec<AttributeInfo>,
609    cp: &mut ConstantPoolBuilder,
610    exceptions: &[String],
611) {
612    attributes.retain(|attr| !matches!(attr, AttributeInfo::Exceptions { .. }));
613    let exception_index_table = exceptions
614        .iter()
615        .map(|exception| cp.class(exception))
616        .collect();
617    attributes.push(AttributeInfo::Exceptions {
618        exception_index_table,
619    });
620}
621
622fn add_local_variables(
623    class_node: &mut rust_asm::nodes::ClassNode,
624    method_metadata: &[MethodMetadata],
625    code_lengths: &HashMap<(String, String), u16>,
626) {
627    let mut cp = ConstantPoolBuilder::from_pool(class_node.constant_pool.clone());
628    cp.utf8("LocalVariableTable");
629
630    for (method, metadata) in class_node.methods.iter_mut().zip(method_metadata) {
631        if method.name != metadata.name || method.descriptor != metadata.descriptor {
632            continue;
633        }
634        let Some(length) = code_lengths
635            .get(&(metadata.name.clone(), metadata.descriptor.clone()))
636            .copied()
637            .filter(|length| *length > 0)
638        else {
639            continue;
640        };
641
642        method.local_variables.clear();
643        for variable in &metadata.local_variables {
644            method.local_variables.push(LocalVariable {
645                start_pc: 0,
646                length,
647                name_index: cp.utf8(&variable.name),
648                descriptor_index: cp.utf8(&variable.descriptor),
649                index: variable.index,
650            });
651        }
652    }
653
654    class_node.constant_pool = cp.into_pool();
655}
656
657fn method_code_lengths(bytes: &[u8]) -> Result<HashMap<(String, String), u16>, String> {
658    let class_file = read_class_file(bytes).map_err(|e| format!("{:?}", e))?;
659    let mut lengths = HashMap::new();
660
661    for method in class_file.methods {
662        let name = cp_utf8(&class_file.constant_pool, method.name_index)?.to_string();
663        let descriptor = cp_utf8(&class_file.constant_pool, method.descriptor_index)?.to_string();
664        let length = method
665            .attributes
666            .iter()
667            .find_map(|attr| match attr {
668                AttributeInfo::Code(code) => Some(code.code.len().min(u16::MAX as usize) as u16),
669                _ => None,
670            })
671            .unwrap_or(0);
672        lengths.insert((name, descriptor), length);
673    }
674
675    Ok(lengths)
676}
677
678fn cp_utf8(pool: &[CpInfo], index: u16) -> Result<&str, String> {
679    match pool.get(index as usize) {
680        Some(CpInfo::Utf8(value)) => Ok(value.as_str()),
681        _ => Err(format!("invalid UTF-8 constant pool index {index}")),
682    }
683}